Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the ShellCode data plugin #20

Merged
merged 1 commit into from
Nov 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ tests = [
[project.entry-points.'aiida.calculations']
'core.shell' = 'aiida_shell.calculations.shell:ShellJob'

[project.entry-points.'aiida.data']
'core.code.installed.shell' = 'aiida_shell.data.code:ShellCode'

[project.entry-points.'aiida.parsers']
'core.shell' = 'aiida_shell.parsers.shell:ShellParser'

Expand Down Expand Up @@ -109,6 +112,7 @@ disable = [
'duplicate-code',
'import-outside-toplevel',
'inconsistent-return-statements',
'too-many-ancestors',
]

[tool.yapf]
Expand Down
1 change: 1 addition & 0 deletions src/aiida_shell/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
"""AiiDA plugin that makes running shell commands easy."""
from .calculations import ShellJob
from .data import ShellCode
from .engine import launch_shell_job
from .parsers import ShellParser

Expand Down
3 changes: 3 additions & 0 deletions src/aiida_shell/calculations/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def validate_nodes(cls, value: t.Mapping[str, Data], _) -> str | None:
@classmethod
def validate_outputs(cls, value: List, _) -> str | None:
"""Validate the ``outputs`` input."""
if not value:
return None

for reserved in [cls.FILENAME_STATUS, cls.FILENAME_STDERR, cls.FILENAME_STDOUT]:
if reserved in value:
return f'`{reserved}` is a reserved output filename and cannot be used in `outputs`.'
Expand Down
5 changes: 5 additions & 0 deletions src/aiida_shell/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# -*- coding: utf-8 -*-
"""Module for :mod:`aiida_shell.data`."""
from .code import ShellCode

__all__ = ('ShellCode',)
39 changes: 39 additions & 0 deletions src/aiida_shell/data/code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
"""Code that represents a shell command."""
from __future__ import annotations

from aiida.orm import InstalledCode

__all__ = ('ShellCode',)


class ShellCode(InstalledCode):
"""Code that represents a shell command.

This code type is automatically generated by the :func:`~aiida_shell.engine.launch.launch_shell_job` function and is
a subclass of :class:`~aiida.orm.nodes.data.code.installed.InstalledCode`. It can therefore be used for any other
calculation job as well.
"""

def __init__(self, *args, default_calc_job_plugin: str = 'core.shell', **kwargs) -> None:
"""Construct a new instance."""
self.validate_default_calc_job_plugin(default_calc_job_plugin)
super().__init__(*args, default_calc_job_plugin=default_calc_job_plugin, **kwargs)

@staticmethod
def validate_default_calc_job_plugin(default_calc_job_plugin: str) -> None:
"""Validate the default calculation job plugin.

The ``ShellCode`` should only be used with the ``core.shell`` calculation job entry point.

:raises ValueError: If ``default_calc_job_plugin`` is not ``core.shell``.
"""
if default_calc_job_plugin != 'core.shell':
raise ValueError(f'`default_calc_job_plugin` has to be `core.shell`, but got: {default_calc_job_plugin}')

@classmethod
def _get_cli_options(cls) -> dict:
"""Return the CLI options that would allow to create an instance of this class."""
options = super()._get_cli_options()
options['default_calc_job_plugin']['default'] = 'core.shell'
return options
61 changes: 37 additions & 24 deletions src/aiida_shell/engine/launchers/shell_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@

from aiida.common import exceptions
from aiida.engine import launch
from aiida.orm import Code, Computer, Data, ProcessNode, SinglefileData, load_code, load_computer
from aiida.orm import AbstractCode, Computer, Data, ProcessNode, SinglefileData, load_code, load_computer

from aiida_shell.calculations.shell import ShellJob
from aiida_shell import ShellCode, ShellJob

__all__ = ('launch_shell_job',)

LOGGER = logging.getLogger('aiida_shell')


def launch_shell_job( # pylint: disable=too-many-arguments,too-many-locals
def launch_shell_job( # pylint: disable=too-many-arguments
command: str,
nodes: dict[str, Data] | None = None,
filenames: dict[str, str] | None = None,
Expand All @@ -40,7 +40,36 @@ def launch_shell_job( # pylint: disable=too-many-arguments,too-many-locals
:raises ValueError: If the absolute path of the command on the computer could not be determined.
:returns: The tuple of results dictionary and ``ProcessNode``, or just the ``ProcessNode`` if ``submit=True``.
"""
computer = prepare_computer((metadata or {}).get('options', {}).pop('computer', None))
computer = (metadata or {}).get('options', {}).pop('computer', None)
code = prepare_code(command, computer)

inputs = {
'code': code,
'nodes': convert_nodes_single_file_data(nodes or {}),
'filenames': filenames,
'arguments': arguments,
'outputs': outputs,
'metadata': metadata or {},
}

if submit:
return launch.submit(ShellJob, **inputs)

results, node = launch.run_get_node(ShellJob, **inputs)

return {label: node for label, node in results.items() if isinstance(node, SinglefileData)}, node


def prepare_code(command: str, computer: Computer | None = None) -> AbstractCode:
"""Prepare a code for the given command and computer.

This will automatically prepare the computer

:param command: The command that the code should represent. Can be the relative executable name or absolute path.
:param computer: The computer on which the command should be run. If not defined the localhost will be used.
:return: A :class:`aiida.orm.nodes.code.abstract.AbstractCode` instance.
"""
computer = prepare_computer(computer)

with computer.get_transport() as transport:
status, stdout, stderr = transport.exec_command_wait(f'which {command}')
Expand All @@ -52,30 +81,14 @@ def launch_shell_job( # pylint: disable=too-many-arguments,too-many-locals
code_label = f'{command}@{computer.label}'

try:
code = load_code(code_label)
code: AbstractCode = load_code(code_label)
except exceptions.NotExistent:
LOGGER.info('No code exists yet for `%s`, creating it now.', code_label)
code = Code( # type: ignore[assignment]
label=command,
remote_computer_exec=(computer, executable),
input_plugin_name='core.shell'
code = ShellCode( # type: ignore[assignment]
label=command, computer=computer, filepath_executable=executable, default_calc_job_plugin='core.shell'
).store()

inputs = {
'code': code,
'nodes': convert_nodes_single_file_data(nodes or {}),
'filenames': filenames or {},
'arguments': arguments or [],
'outputs': outputs or [],
'metadata': metadata or {},
}

if submit:
return launch.submit(ShellJob, **inputs)

results, node = launch.run_get_node(ShellJob, **inputs)

return {label: node for label, node in results.items() if isinstance(node, SinglefileData)}, node
return code


def prepare_computer(computer: Computer | None = None) -> Computer:
Expand Down
18 changes: 12 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from aiida.common.links import LinkType
from aiida.engine.utils import instantiate_process
from aiida.manage.manager import get_manager
from aiida.orm import CalcJobNode, Code, Computer, FolderData
from aiida.orm import CalcJobNode, Computer, FolderData
from aiida.plugins import CalculationFactory, ParserFactory
import pytest

from aiida_shell import ShellCode

pytest_plugins = ['aiida.manage.tests.pytest_fixtures'] # pylint: disable=invalid-name


Expand Down Expand Up @@ -135,10 +137,10 @@ def factory(label='localhost', hostname='localhost', scheduler_type='core.direct

@pytest.fixture
def generate_code(generate_computer):
"""Return a :class:`aiida.orm.Code` instance, either already existing or created."""
"""Return a :class:`aiida_shell.data.code.ShellCode` instance, either already existing or created."""

def factory(command='/bin/true', computer_label='localhost', label=None, entry_point_name='core.shell'):
"""Return a :class:`aiida.orm.Code` instance, either already existing or created."""
"""Return a :class:`aiida_shell.data.code.ShellCode` instance, either already existing or created."""
label = label or str(uuid.uuid4())
computer = generate_computer(computer_label)

Expand All @@ -151,10 +153,14 @@ def factory(command='/bin/true', computer_label='localhost', label=None, entry_p

try:
filters = {'label': label, 'attributes.input_plugin_name': entry_point_name}
return Code.collection.get(**filters)
return ShellCode.collection.get(**filters)
except exceptions.NotExistent:
code = Code(label=label, input_plugin_name=entry_point_name, remote_computer_exec=[computer, executable])
return code.store()
return ShellCode(
label=label,
computer=computer,
filepath_executable=executable,
default_calc_job_plugin=entry_point_name
).store()

return factory

Expand Down
40 changes: 40 additions & 0 deletions tests/data/test_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
"""Tests for :mod:`aiida_shell.data.code`."""
import pytest

from aiida_shell.data.code import ShellCode


def test_constructor(generate_computer):
"""Test initializing an instance."""
code = ShellCode(
label='bash',
computer=generate_computer(),
filepath_executable='/bin/bash',
default_calc_job_plugin='core.shell',
)
assert isinstance(code, ShellCode)


def test_constructor_invalid(generate_computer):
"""Test the constructor raises if ``default_calc_job_plugin`` is not ``core.shell``."""
with pytest.raises(ValueError, match=r'`default_calc_job_plugin` has to be `core.shell`, but got: .*'):
ShellCode(
label='bash',
computer=generate_computer(),
filepath_executable='/bin/bash',
default_calc_job_plugin='core.arithmetic.add',
)


@pytest.mark.parametrize(('value', 'exception'), (
('core.shell', None),
('core.arithmetic.add', r'`default_calc_job_plugin` has to be `core.shell`, but got: .*'),
))
def test_validate_default_calc_job_plugin(value, exception):
"""Test the constructor raises if ``default_calc_job_plugin`` is not ``core.shell``."""
if exception:
with pytest.raises(ValueError, match=exception):
ShellCode.validate_default_calc_job_plugin(value)
else:
assert ShellCode.validate_default_calc_job_plugin(value) is None