Skip to content

Commit

Permalink
refactor template step
Browse files Browse the repository at this point in the history
  • Loading branch information
patricktnast committed Oct 3, 2024
1 parent 8985f46 commit a0e030e
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 61 deletions.
18 changes: 0 additions & 18 deletions src/easylink/pipeline_schema_constants/development.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,6 @@
NODES = [
InputStep(),
ParallelStep(
step_name="step_1",
input_slots=[
InputSlot(
name="step_1_main_input",
env_var="DUMMY_CONTAINER_MAIN_INPUT_FILE_PATHS",
validator=validate_input_file_dummy,
),
],
output_slots=[OutputSlot("step_1_main_output")],
template_step=GenericStep(
step_name="step_1",
input_slots=[
Expand All @@ -51,15 +42,6 @@
output_slots=[OutputSlot("step_2_main_output")],
),
LoopStep(
step_name="step_3",
input_slots=[
InputSlot(
name="step_3_main_input",
env_var="DUMMY_CONTAINER_MAIN_INPUT_FILE_PATHS",
validator=validate_input_file_dummy,
),
],
output_slots=[OutputSlot("step_3_main_output")],
template_step=GenericStep(
step_name="step_3",
input_slots=[
Expand Down
28 changes: 8 additions & 20 deletions src/easylink/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,25 +455,17 @@ class TemplateStep(Step):

def __init__(
self,
step_name: str,
name: str = None,
input_slots: Iterable[InputSlot] = (),
output_slots: Iterable[OutputSlot] = (),
template_step: Step = None,
template_step: Step,
) -> None:
super().__init__(step_name, name, input_slots, output_slots)
if not template_step or template_step.name != step_name:
raise NotImplementedError(
f"{self.__class__} {self.name} must be initialized with a single node with the same name."
)
super().__init__(
template_step.step_name,
template_step.name,
template_step.input_slots.values(),
template_step.output_slots.values(),
)
self.template_step = template_step
self.template_step.set_parent_step(self)

@property
@abstractmethod
def config_key(self) -> str:
pass

@property
@abstractmethod
def node_prefix(self) -> str:
Expand Down Expand Up @@ -556,14 +548,10 @@ class LoopStep(TemplateStep):

def __init__(
self,
step_name: str,
name: str = None,
input_slots: Iterable[InputSlot] = (),
output_slots: Iterable[OutputSlot] = (),
template_step: Step = None,
self_edges: Iterable[EdgeParams] = (),
) -> None:
super().__init__(step_name, name, input_slots, output_slots, template_step)
super().__init__(template_step)
self.self_edges = self_edges

@property
Expand Down
23 changes: 0 additions & 23 deletions tests/unit/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,20 +323,6 @@ def test_hierarchical_step_update_implementation_graph(
@pytest.fixture
def loop_step_params() -> Dict[str, Any]:
return {
"step_name": "step_3",
"input_slots": [
InputSlot(
name="step_3_main_input",
env_var="DUMMY_CONTAINER_MAIN_INPUT_FILE_PATHS",
validator=validate_input_file_dummy,
),
InputSlot(
name="step_3_secondary_input",
env_var="DUMMY_CONTAINER_SECONDARY_INPUT_FILE_PATHS",
validator=validate_input_file_dummy,
),
],
"output_slots": [OutputSlot("step_3_main_output")],
"template_step": HierarchicalStep(
"step_3",
input_slots=[
Expand Down Expand Up @@ -544,15 +530,6 @@ def test_loop_update_implementation_graph(
@pytest.fixture
def parallel_step_params() -> Dict[str, Any]:
return {
"step_name": "step_1",
"input_slots": [
InputSlot(
"step_1_main_input",
"DUMMY_CONTAINER_MAIN_INPUT_FILE_PATHS",
validate_input_file_dummy,
)
],
"output_slots": [OutputSlot("step_1_main_output")],
"template_step": HierarchicalStep(
"step_1",
input_slots=[
Expand Down

0 comments on commit a0e030e

Please sign in to comment.