Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ShellJob: Raise when < or > are specified in arguments #28

Merged
merged 1 commit into from
Nov 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ print(results['stdout'].get_content())
```
which prints `string a`.

N.B.: one might be tempted to simply define the `arguments` as `['<', '{input}']`, but this won't work as the `<` symbol will be quoted and will be read as a literal command line argument, not as the redirection symbol.
This is why passing the `<` in the `arguments` input will result in a validation error.

### Defining output files
When the shell command is executed, AiiDA captures by default the content written to the stdout and stderr file descriptors.
The content is wrapped in a `SinglefileData` node and attached to the `ShellJob` with the `stdout` and `stderr` link labels, respectively.
Expand Down
22 changes: 21 additions & 1 deletion src/aiida_shell/calculations/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def define(cls, spec: CalcJobProcessSpec): # type: ignore[override]
super().define(spec)
spec.input_namespace('nodes', valid_type=Data, required=False, validator=cls.validate_nodes)
spec.input('filenames', valid_type=Dict, required=False, serializer=to_aiida_type)
spec.input('arguments', valid_type=List, required=False, serializer=to_aiida_type)
spec.input(
'arguments', valid_type=List, required=False, serializer=to_aiida_type, validator=cls.validate_arguments
)
spec.input('outputs', valid_type=List, required=False, serializer=to_aiida_type, validator=cls.validate_outputs)
spec.input(
'metadata.options.filename_stdin',
Expand Down Expand Up @@ -87,6 +89,24 @@ def validate_nodes(cls, value: t.Mapping[str, Data], _) -> str | None:
except Exception as exception: # pylint: disable=broad-except
return f'Casting `value` to `str` for `{key}` in `nodes` excepted: {exception}'

@classmethod
def validate_arguments(cls, value: List, _) -> str | None:
"""Validate the ``arguments`` input."""
if not value:
return None

elements = value.get_list()

if any(not isinstance(element, str) for element in elements):
return 'all elements of the `arguments` input should be strings'

if '<' in elements:
var = 'metadata.options.filename_stdin'
return f'`<` cannot be specified in the `arguments`; to redirect a file to stdin, use the `{var}` input.'

if '>' in elements:
return 'the symbol `>` cannot be specified in the `arguments`; stdout is automatically redirected.'

@classmethod
def validate_outputs(cls, value: List, _) -> str | None:
"""Validate the ``outputs`` input."""
Expand Down
14 changes: 14 additions & 0 deletions tests/calculations/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,20 @@ def value_raises(self):
generate_calc_job('core.shell', {'code': generate_code(), 'nodes': nodes})


@pytest.mark.parametrize(
'arguments, message', (
(['string', 1], r'.*all elements of the `arguments` input should be strings'),
(['string', {input}], r'.*all elements of the `arguments` input should be strings'),
(['<', '{filename}'], r'`<` cannot be specified in the `arguments`.*'),
(['{filename}', '>'], r'the symbol `>` cannot be specified in the `arguments`.*'),
)
)
def test_validate_arguments(generate_calc_job, generate_code, arguments, message):
"""Test the validator for the ``arguments`` argument."""
with pytest.raises(ValueError, match=message):
generate_calc_job('core.shell', {'code': generate_code(), 'arguments': arguments})


def test_build_process_label(generate_calc_job, generate_code):
"""Test the :meth:`~aiida_shell.calculations.shell_job.ShellJob.build_process_label` method."""
computer = 'localhost'
Expand Down