Commit

Merge pull request #1529 from pyiron/exe_auto_job_name
wrap_executable(): no longer require job_name - enable automatic renaming
jan-janssen committed Jul 15, 2024
2 parents 82fecbf + ad77e29 commit 8d45d6a
Showing 4 changed files with 644 additions and 17 deletions.
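To make the new calling convention concrete, here is a minimal usage sketch: `job_name` is simply omitted, so the job falls back to the base name "exe" and automatic renaming kicks in, appending a hash derived from the wrapped functions and the job's inputs (see the `generate_job_hash` helper in the diff below). The project name and the `write_input`/`collect_output` helpers are hypothetical stand-ins, not part of this commit.

```python
import os

from pyiron_base import Project

pr = Project("wrap_demo")  # hypothetical project name


def write_input(input_dict, working_directory):
    # Write each input value into its own file, e.g. x.txt and y.txt.
    for key, value in input_dict.items():
        with open(os.path.join(working_directory, f"{key}.txt"), "w") as f:
            f.write(str(value))


def collect_output(working_directory):
    # Read the summed result back from result.txt.
    with open(os.path.join(working_directory, "result.txt")) as f:
        return {"result": int(f.read())}


# After this commit job_name may be omitted: the job is created as "exe_<hash>".
job = pr.wrap_executable(
    executable_str="x=$(cat x.txt); y=$(cat y.txt); echo $(($x + $y)) > result.txt",
    write_input_funct=write_input,
    collect_output_funct=collect_output,
    input_dict={"x": 1, "y": 2},
    execute_job=True,
)
print(job.output.result)  # 3
```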
81 changes: 68 additions & 13 deletions pyiron_base/project/generic.py
@@ -13,6 +13,7 @@
import stat
from typing import TYPE_CHECKING, Dict, Generator, Literal, Union

import cloudpickle
import numpy as np
import pandas
from pyiron_snippets.deprecate import deprecate
@@ -47,7 +48,7 @@
from pyiron_base.jobs.job.util import _get_safe_job_name, _special_symbol_replacements
from pyiron_base.project.archiving import export_archive, import_archive
from pyiron_base.project.data import ProjectData
from pyiron_base.project.delayed import DelayedObject
from pyiron_base.project.delayed import DelayedObject, get_hash
from pyiron_base.project.external import Notebook
from pyiron_base.project.jobloader import JobInspector, JobLoader
from pyiron_base.project.path import ProjectPath
@@ -367,14 +368,15 @@ def create_job_class(

def wrap_executable(
self,
job_name,
executable_str,
job_name=None,
write_input_funct=None,
collect_output_funct=None,
input_dict=None,
conda_environment_path=None,
conda_environment_name=None,
input_file_lst=None,
automatically_rename=False,
execute_job=False,
delayed=False,
output_file_lst=[],
@@ -384,15 +386,21 @@ def wrap_executable(
Wrap any executable into a pyiron job object using the ExecutableContainerJob.
Args:
job_name (str): name of the new job object
executable_str (str): call to an external executable
job_name (str): name of the new job object
write_input_funct (callable): The write input function write_input(input_dict, working_directory)
collect_output_funct (callable): The collect output function collect_output(working_directory)
input_dict (dict): Default input for the newly created job class
conda_environment_path (str): path of the conda environment
conda_environment_name (str): name of the conda environment
input_file_lst (list): list of files to be copied to the working directory before executing it\
execute_job (boolean): automatically call run() on the job object - default false
automatically_rename (bool): Whether to automatically rename the job at
save-time to append a string based on the input values. (Default is
False.)
delayed (bool): delayed execution
output_file_lst (list):
output_key_lst (list):
Example:
@@ -422,43 +430,86 @@ def wrap_executable(
pyiron_base.jobs.flex.ExecutableContainerJob: pyiron job object
"""

def create_exeuctable_job(
def generate_job_hash(
project,
input_internal_dict,
executable_internal_str,
internal_file_lst,
internal_job_name=None,
):
job = create_job_factory(
write_input_funct=write_input_funct,
collect_output_funct=collect_output_funct,
default_input_dict=input_internal_dict,
executable_str=executable_internal_str,
)(project=project, job_name=internal_job_name)
if internal_file_lst is not None and len(internal_file_lst) > 0:
for file in internal_file_lst:
job.restart_file_list.append(file)
return (
internal_job_name
+ "_"
+ get_hash(
binary=cloudpickle.dumps(
{
"write_input": write_input_funct,
"collect_output": collect_output_funct,
"kwargs": job.calculate_kwargs,
}
)
)
)

def create_executable_job(
project,
input_internal_dict,
executable_internal_str,
internal_file_lst,
execute_job=True,
internal_job_name=None,
internal_execute_job=True,
internal_auto_rename=False,
):
if internal_job_name is None:
internal_job_name = "exe"
internal_auto_rename = True
if internal_auto_rename:
internal_job_name = generate_job_hash(
project=project,
input_internal_dict=input_internal_dict,
executable_internal_str=executable_internal_str,
internal_file_lst=internal_file_lst,
internal_job_name=internal_job_name,
)
job_id = get_job_id(
database=project.db,
sql_query=project.sql_query,
user=project.user,
project_path=project.project_path,
job_specifier=job_name,
job_specifier=internal_job_name,
)
if job_id is None:
job = create_job_factory(
write_input_funct=write_input_funct,
collect_output_funct=collect_output_funct,
default_input_dict=input_internal_dict,
executable_str=executable_internal_str,
)(project=project, job_name=job_name)
)(project=project, job_name=internal_job_name)
else:
return project.load(job_specifier=job_name)
return project.load(job_specifier=job_id)
if conda_environment_path is not None:
job.server.conda_environment_path = conda_environment_path
elif conda_environment_name is not None:
job.server.conda_environment_name = conda_environment_name
if internal_file_lst is not None and len(internal_file_lst) > 0:
for file in internal_file_lst:
job.restart_file_list.append(file)
if execute_job:
if internal_execute_job:
job.run()
return job

if delayed:
return DelayedObject(
function=create_exeuctable_job,
function=create_executable_job,
output_key=None,
output_file=None,
output_file_lst=[f.replace(".", "_") for f in output_file_lst],
@@ -467,15 +518,19 @@ def create_exeuctable_job(
input_internal_dict=input_dict,
executable_internal_str=executable_str,
internal_file_lst=input_file_lst,
execute_job=True,
internal_job_name=job_name,
internal_auto_rename=automatically_rename,
internal_execute_job=True,
)
else:
return create_exeuctable_job(
return create_executable_job(
project=self,
input_internal_dict=input_dict,
executable_internal_str=executable_str,
internal_file_lst=input_file_lst,
execute_job=execute_job,
internal_job_name=job_name,
internal_auto_rename=automatically_rename,
internal_execute_job=execute_job,
)

def create_job(self, job_type, job_name, delete_existing_job=False):
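For orientation, a rough sketch of what the new `generate_job_hash` helper above computes: it cloudpickles the write-input function, the collect-output function, and the job's `calculate_kwargs`, passes the resulting bytes to `get_hash` from `pyiron_base.project.delayed`, and appends the digest to the base job name. The standalone function below only illustrates the idea with a plain SHA-256 digest; the actual `get_hash` implementation may differ.

```python
import hashlib

import cloudpickle


def illustrative_job_hash(base_name, write_input_funct, collect_output_funct, calculate_kwargs):
    """Illustrative stand-in for the naming logic; not the pyiron_base implementation."""
    binary = cloudpickle.dumps(
        {
            "write_input": write_input_funct,
            "collect_output": collect_output_funct,
            "kwargs": calculate_kwargs,
        }
    )
    # pyiron_base delegates the hashing to get_hash(binary=...); SHA-256 stands in here.
    return base_name + "_" + hashlib.sha256(binary).hexdigest()
```

Because identical functions and inputs produce identical names, `create_executable_job` can look the generated name up via `get_job_id` and reload an existing job from the database instead of recomputing it, which is what makes the automatic renaming safe for repeated calls.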
70 changes: 68 additions & 2 deletions tests/unit/flex/test_executablecontainer.py
@@ -231,6 +231,33 @@ def test_series_of_jobs(self):
)
self.assertEqual(w.output.result, 7)

@unittest.skipIf(
os.name == "nt",
"shell script test is skipped on windows.",
)
    def test_series_of_jobs_without_job_name(self):
z = self.project.wrap_executable(
executable_str="x=$(cat x.txt); y=$(cat y.txt); echo $(($x + $y)) > result.txt",
write_input_funct=write_input_series,
collect_output_funct=collect_output_series,
input_dict={"x": 1, "y": 2},
conda_environment_path=None,
conda_environment_name=None,
input_file_lst=None,
execute_job=True,
)
w = self.project.wrap_executable(
executable_str="x=$(cat x.txt); y=$(cat y.txt); z=$(cat result.txt); echo $(($x + $y + $z)) > result.txt",
write_input_funct=write_input_series,
collect_output_funct=collect_output_series,
input_dict={"x": 1, "y": z.output.result},
conda_environment_path=None,
conda_environment_name=None,
input_file_lst=[z.files.result_txt],
execute_job=True,
)
self.assertEqual(w.output.result, 7)

@unittest.skipIf(
os.name == "nt",
"delayed shell script test is skipped on windows.",
@@ -264,8 +291,47 @@ def test_delayed_series_of_jobs(self):
)
self.assertEqual(w.output.result.pull(), 7)
nodes_dict, edges_lst = w.get_graph()
self.assertEqual(len(nodes_dict), 12)
self.assertEqual(len(edges_lst), 18)
self.assertEqual(len(nodes_dict), 15)
self.assertEqual(len(edges_lst), 24)
job_w = w.pull()
job_z = z.pull()
self.assertEqual(job_w.output.result, 7)
self.project.remove_job(job_z.job_name)
self.project.remove_job(job_w.job_name)

@unittest.skipIf(
os.name == "nt",
"delayed shell script test is skipped on windows.",
)
def test_delayed_series_of_jobs_without_job_name(self):
z = self.project.wrap_executable(
executable_str="x=$(cat x.txt); y=$(cat y.txt); echo $(($x + $y)) > result.txt",
write_input_funct=write_input_series,
collect_output_funct=collect_output_series,
input_dict={"x": 1, "y": 2},
conda_environment_path=None,
conda_environment_name=None,
input_file_lst=None,
delayed=True,
output_file_lst=["result.txt"],
output_key_lst=["result"],
)
w = self.project.wrap_executable(
executable_str="x=$(cat x.txt); y=$(cat y.txt); z=$(cat result.txt); echo $(($x + $y + $z)) > result.txt",
write_input_funct=write_input_series,
collect_output_funct=collect_output_series,
input_dict={"x": 1, "y": z.output.result},
conda_environment_path=None,
conda_environment_name=None,
input_file_lst=[z.files.result_txt],
delayed=True,
output_file_lst=["result.txt"],
output_key_lst=["result"],
)
self.assertEqual(w.output.result.pull(), 7)
nodes_dict, edges_lst = w.get_graph()
self.assertEqual(len(nodes_dict), 14)
self.assertEqual(len(edges_lst), 24)
job_w = w.pull()
job_z = z.pull()
self.assertEqual(job_w.output.result, 7)
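The delayed tests above exercise the same automatic naming through `DelayedObject`: with `delayed=True`, `wrap_executable` returns a lazy handle instead of a finished job, and nothing is executed until `.pull()` is called. A short sketch of that pattern, reusing the hypothetical `pr`, `write_input`, and `collect_output` from the first example:

```python
z = pr.wrap_executable(
    executable_str="x=$(cat x.txt); y=$(cat y.txt); echo $(($x + $y)) > result.txt",
    write_input_funct=write_input,
    collect_output_funct=collect_output,
    input_dict={"x": 1, "y": 2},
    delayed=True,
    output_file_lst=["result.txt"],
    output_key_lst=["result"],
)
# z.output.result is itself delayed and only runs the wrapped executable on pull().
print(z.output.result.pull())  # 3
nodes_dict, edges_lst = z.get_graph()  # inspect the lazy execution graph
```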
261 changes: 260 additions & 1 deletion tests/usecases/ADIS/notebook.ipynb

Large diffs are not rendered by default.
