Skip to content

Commit

Permalink
Add the ShellCode data plugin (#20)
Browse files Browse the repository at this point in the history
The `ShellCode` data plugin is a subclass of the `InstalledCode` plugin
from `aiida-core`. The new plugin is used by the `launch_shell_job`
utility function whenever a code has to be created on the fly. This will
serve to be able to query for these kinds of codes and to distinguish
them from codes that are typically setup manually.
  • Loading branch information
sphuber authored Nov 6, 2022
1 parent 506fe91 commit a185219
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 30 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ tests = [
[project.entry-points.'aiida.calculations']
'core.shell' = 'aiida_shell.calculations.shell:ShellJob'

[project.entry-points.'aiida.data']
'core.code.installed.shell' = 'aiida_shell.data.code:ShellCode'

[project.entry-points.'aiida.parsers']
'core.shell' = 'aiida_shell.parsers.shell:ShellParser'

Expand Down Expand Up @@ -109,6 +112,7 @@ disable = [
'duplicate-code',
'import-outside-toplevel',
'inconsistent-return-statements',
'too-many-ancestors',
]

[tool.yapf]
Expand Down
1 change: 1 addition & 0 deletions src/aiida_shell/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
"""AiiDA plugin that makes running shell commands easy."""
from .calculations import ShellJob
from .data import ShellCode
from .engine import launch_shell_job
from .parsers import ShellParser

Expand Down
3 changes: 3 additions & 0 deletions src/aiida_shell/calculations/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def validate_nodes(cls, value: t.Mapping[str, Data], _) -> str | None:
@classmethod
def validate_outputs(cls, value: List, _) -> str | None:
"""Validate the ``outputs`` input."""
if not value:
return None

for reserved in [cls.FILENAME_STATUS, cls.FILENAME_STDERR, cls.FILENAME_STDOUT]:
if reserved in value:
return f'`{reserved}` is a reserved output filename and cannot be used in `outputs`.'
Expand Down
5 changes: 5 additions & 0 deletions src/aiida_shell/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# -*- coding: utf-8 -*-
"""Module for :mod:`aiida_shell.data`."""
from .code import ShellCode

__all__ = ('ShellCode',)
39 changes: 39 additions & 0 deletions src/aiida_shell/data/code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
"""Code that represents a shell command."""
from __future__ import annotations

from aiida.orm import InstalledCode

__all__ = ('ShellCode',)


class ShellCode(InstalledCode):
"""Code that represents a shell command.
This code type is automatically generated by the :func:`~aiida_shell.engine.launch.launch_shell_job` function and is
a subclass of :class:`~aiida.orm.nodes.data.code.installed.InstalledCode`. It can therefore be used for any other
calculation job as well.
"""

def __init__(self, *args, default_calc_job_plugin: str = 'core.shell', **kwargs) -> None:
"""Construct a new instance."""
self.validate_default_calc_job_plugin(default_calc_job_plugin)
super().__init__(*args, default_calc_job_plugin=default_calc_job_plugin, **kwargs)

@staticmethod
def validate_default_calc_job_plugin(default_calc_job_plugin: str) -> None:
"""Validate the default calculation job plugin.
The ``ShellCode`` should only be used with the ``core.shell`` calculation job entry point.
:raises ValueError: If ``default_calc_job_plugin`` is not ``core.shell``.
"""
if default_calc_job_plugin != 'core.shell':
raise ValueError(f'`default_calc_job_plugin` has to be `core.shell`, but got: {default_calc_job_plugin}')

@classmethod
def _get_cli_options(cls) -> dict:
"""Return the CLI options that would allow to create an instance of this class."""
options = super()._get_cli_options()
options['default_calc_job_plugin']['default'] = 'core.shell'
return options
61 changes: 37 additions & 24 deletions src/aiida_shell/engine/launchers/shell_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@

from aiida.common import exceptions
from aiida.engine import launch
from aiida.orm import Code, Computer, Data, ProcessNode, SinglefileData, load_code, load_computer
from aiida.orm import AbstractCode, Computer, Data, ProcessNode, SinglefileData, load_code, load_computer

from aiida_shell.calculations.shell import ShellJob
from aiida_shell import ShellCode, ShellJob

__all__ = ('launch_shell_job',)

LOGGER = logging.getLogger('aiida_shell')


def launch_shell_job( # pylint: disable=too-many-arguments,too-many-locals
def launch_shell_job( # pylint: disable=too-many-arguments
command: str,
nodes: dict[str, Data] | None = None,
filenames: dict[str, str] | None = None,
Expand All @@ -40,7 +40,36 @@ def launch_shell_job( # pylint: disable=too-many-arguments,too-many-locals
:raises ValueError: If the absolute path of the command on the computer could not be determined.
:returns: The tuple of results dictionary and ``ProcessNode``, or just the ``ProcessNode`` if ``submit=True``.
"""
computer = prepare_computer((metadata or {}).get('options', {}).pop('computer', None))
computer = (metadata or {}).get('options', {}).pop('computer', None)
code = prepare_code(command, computer)

inputs = {
'code': code,
'nodes': convert_nodes_single_file_data(nodes or {}),
'filenames': filenames,
'arguments': arguments,
'outputs': outputs,
'metadata': metadata or {},
}

if submit:
return launch.submit(ShellJob, **inputs)

results, node = launch.run_get_node(ShellJob, **inputs)

return {label: node for label, node in results.items() if isinstance(node, SinglefileData)}, node


def prepare_code(command: str, computer: Computer | None = None) -> AbstractCode:
"""Prepare a code for the given command and computer.
This will automatically prepare the computer
:param command: The command that the code should represent. Can be the relative executable name or absolute path.
:param computer: The computer on which the command should be run. If not defined the localhost will be used.
:return: A :class:`aiida.orm.nodes.code.abstract.AbstractCode` instance.
"""
computer = prepare_computer(computer)

with computer.get_transport() as transport:
status, stdout, stderr = transport.exec_command_wait(f'which {command}')
Expand All @@ -52,30 +81,14 @@ def launch_shell_job( # pylint: disable=too-many-arguments,too-many-locals
code_label = f'{command}@{computer.label}'

try:
code = load_code(code_label)
code: AbstractCode = load_code(code_label)
except exceptions.NotExistent:
LOGGER.info('No code exists yet for `%s`, creating it now.', code_label)
code = Code( # type: ignore[assignment]
label=command,
remote_computer_exec=(computer, executable),
input_plugin_name='core.shell'
code = ShellCode( # type: ignore[assignment]
label=command, computer=computer, filepath_executable=executable, default_calc_job_plugin='core.shell'
).store()

inputs = {
'code': code,
'nodes': convert_nodes_single_file_data(nodes or {}),
'filenames': filenames or {},
'arguments': arguments or [],
'outputs': outputs or [],
'metadata': metadata or {},
}

if submit:
return launch.submit(ShellJob, **inputs)

results, node = launch.run_get_node(ShellJob, **inputs)

return {label: node for label, node in results.items() if isinstance(node, SinglefileData)}, node
return code


def prepare_computer(computer: Computer | None = None) -> Computer:
Expand Down
18 changes: 12 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from aiida.common.links import LinkType
from aiida.engine.utils import instantiate_process
from aiida.manage.manager import get_manager
from aiida.orm import CalcJobNode, Code, Computer, FolderData
from aiida.orm import CalcJobNode, Computer, FolderData
from aiida.plugins import CalculationFactory, ParserFactory
import pytest

from aiida_shell import ShellCode

pytest_plugins = ['aiida.manage.tests.pytest_fixtures'] # pylint: disable=invalid-name


Expand Down Expand Up @@ -135,10 +137,10 @@ def factory(label='localhost', hostname='localhost', scheduler_type='core.direct

@pytest.fixture
def generate_code(generate_computer):
"""Return a :class:`aiida.orm.Code` instance, either already existing or created."""
"""Return a :class:`aiida_shell.data.code.ShellCode` instance, either already existing or created."""

def factory(command='/bin/true', computer_label='localhost', label=None, entry_point_name='core.shell'):
"""Return a :class:`aiida.orm.Code` instance, either already existing or created."""
"""Return a :class:`aiida_shell.data.code.ShellCode` instance, either already existing or created."""
label = label or str(uuid.uuid4())
computer = generate_computer(computer_label)

Expand All @@ -151,10 +153,14 @@ def factory(command='/bin/true', computer_label='localhost', label=None, entry_p

try:
filters = {'label': label, 'attributes.input_plugin_name': entry_point_name}
return Code.collection.get(**filters)
return ShellCode.collection.get(**filters)
except exceptions.NotExistent:
code = Code(label=label, input_plugin_name=entry_point_name, remote_computer_exec=[computer, executable])
return code.store()
return ShellCode(
label=label,
computer=computer,
filepath_executable=executable,
default_calc_job_plugin=entry_point_name
).store()

return factory

Expand Down
40 changes: 40 additions & 0 deletions tests/data/test_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
"""Tests for :mod:`aiida_shell.data.code`."""
import pytest

from aiida_shell.data.code import ShellCode


def test_constructor(generate_computer):
"""Test initializing an instance."""
code = ShellCode(
label='bash',
computer=generate_computer(),
filepath_executable='/bin/bash',
default_calc_job_plugin='core.shell',
)
assert isinstance(code, ShellCode)


def test_constructor_invalid(generate_computer):
"""Test the constructor raises if ``default_calc_job_plugin`` is not ``core.shell``."""
with pytest.raises(ValueError, match=r'`default_calc_job_plugin` has to be `core.shell`, but got: .*'):
ShellCode(
label='bash',
computer=generate_computer(),
filepath_executable='/bin/bash',
default_calc_job_plugin='core.arithmetic.add',
)


@pytest.mark.parametrize(('value', 'exception'), (
('core.shell', None),
('core.arithmetic.add', r'`default_calc_job_plugin` has to be `core.shell`, but got: .*'),
))
def test_validate_default_calc_job_plugin(value, exception):
"""Test the constructor raises if ``default_calc_job_plugin`` is not ``core.shell``."""
if exception:
with pytest.raises(ValueError, match=exception):
ShellCode.validate_default_calc_job_plugin(value)
else:
assert ShellCode.validate_default_calc_job_plugin(value) is None

0 comments on commit a185219

Please sign in to comment.