Skip to content

Commit

Permalink
refactor template steps
Browse files Browse the repository at this point in the history
  • Loading branch information
patricktnast committed Oct 3, 2024
1 parent f60830d commit 7fe2577
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 108 deletions.
184 changes: 78 additions & 106 deletions src/easylink/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,9 @@ def config_key(self):
return "substeps"


class LoopStep(Step):
"""A LoopStep allows a user to loop a single step or a sequence of steps a user-configured number of times."""
class TemplateStep(Step):
"""TemplateStep is a class of step that transforms a template step into multiple instances according to
a specified transformation rule and the user-specified configuration."""

def __init__(
self,
Expand All @@ -469,30 +470,39 @@ def __init__(
input_slots: Iterable[InputSlot] = (),
output_slots: Iterable[OutputSlot] = (),
template_step: Step = None,
self_edges: Iterable[EdgeParams] = (),
) -> None:
super().__init__(step_name, name, input_slots, output_slots)
if not template_step or template_step.name != step_name:
raise NotImplementedError(
f"LoopStep {self.name} must be initialized with a single node with the same name."
f"{self.__class__} {self.name} must be initialized with a single node with the same name."
)
self.template_step = template_step
self.template_step.set_parent_step(self)
for edge in self_edges:
if not edge.source_node == edge.target_node == step_name:
raise NotImplementedError(
f"LoopStep {self.name} must be initialized with only self-loops as edges"
)
self.self_edges = self_edges

@property
def config_key(self):
return "iterate"
@abstractmethod
def config_key(self) -> str:
pass

@property
@abstractmethod
def node_prefix(self) -> str:
pass

@property
def num_repeats(self):
def num_repeats(self) -> int:
return len(self.config)

@abstractmethod
def _get_step_graph(self) -> StepGraph:
pass

@abstractmethod
def _get_slot_mappings(self) -> dict[str, list[SlotMapping]]:
"""Get the appropriate slot mappings based on the number of parallel copies
and the existing input and output slots."""
pass

def validate_step(
self, step_config: LayeredConfigTree, input_data_config: LayeredConfigTree
) -> dict[str, list[str]]:
Expand All @@ -504,18 +514,30 @@ def validate_step(
if not isinstance(sub_config, list):
return {
f"step {self.name}": [
"Loops must be formatted as a sequence in the pipeline configuration."
f"{self.node_prefix.capitalize()} instances must be formatted as a sequence in the pipeline configuration."
]
}

if len(sub_config) == 0:
return {f"step {self.name}": ["No loops configured under iterate key."]}
return {
f"step {self.name}": [
f"No {self.node_prefix} instances configured under '{self.config_key}' key."
]
}

errors = defaultdict(dict)
for i, loop in enumerate(sub_config):
loop_errors = self.template_step.validate_step(loop, input_data_config)
if loop_errors:
errors[f"step {self.name}"][f"loop {i+1}"] = loop_errors
for i, parallel_config in enumerate(sub_config):
parallel_errors = {}
input_data_file = parallel_config.get("input_data_file")
if input_data_file and not input_data_file in input_data_config:
parallel_errors["Input Data Key"] = [
f"Input data file '{input_data_file}' not found in input data configuration."
]
parallel_errors.update(
self.template_step.validate_step(parallel_config, input_data_config)
)
if parallel_errors:
errors[f"step {self.name}"][f"{self.node_prefix}_{i+1}"] = parallel_errors
return errors

def set_step_config(self, parent_config: LayeredConfigTree) -> None:
Expand All @@ -529,6 +551,40 @@ def set_step_config(self, parent_config: LayeredConfigTree) -> None:
self.slot_mappings = self._get_slot_mappings()
self.layer_state = CompositeState(self)

def _get_expanded_config(
self, step_config: LayeredConfigTree
) -> dict[str, LayeredConfigTree]:
"""Get the dictionary for the parallel graph based on the sequence
of sub-yamls."""
expanded_step_config = {}
for i, sub_config in enumerate(step_config):
expanded_step_config[f"{self.name}_{self.node_prefix}_{i+1}"] = sub_config
return LayeredConfigTree(expanded_step_config)


class LoopStep(TemplateStep):
"""A LoopStep allows a user to loop a single step or a sequence of steps a user-configured number of times."""

def __init__(
self,
step_name: str,
name: str = None,
input_slots: Iterable[InputSlot] = (),
output_slots: Iterable[OutputSlot] = (),
template_step: Step = None,
self_edges: Iterable[EdgeParams] = (),
) -> None:
super().__init__(step_name, name, input_slots, output_slots, template_step)
self.self_edges = self_edges

@property
def config_key(self):
return "iterate"

@property
def node_prefix(self):
return "loop"

def _get_step_graph(self) -> StepGraph:
"""Make N copies of the iterated graph and chain them together according
to the self edges."""
Expand Down Expand Up @@ -583,91 +639,17 @@ def _get_slot_mappings(self) -> dict:
]
return {"input": input_mappings, "output": output_mappings}

def _get_expanded_config(
self, step_config: LayeredConfigTree
) -> dict[str, LayeredConfigTree]:
"""Get the dictionary for the looped graph based on the sequence
of sub-yamls."""
expanded_config = {}
for i, sub_config in enumerate(step_config):
expanded_config[f"{self.name}_loop_{i+1}"] = sub_config
return LayeredConfigTree(expanded_config)


class ParallelStep(Step):
class ParallelStep(TemplateStep):
"""A ParallelStep allows a user to run a sequence of steps in parallel."""

def __init__(
self,
step_name: str,
name: str = None,
input_slots: Iterable[InputSlot] = (),
output_slots: Iterable[OutputSlot] = (),
template_step: Step = None,
) -> None:
super().__init__(step_name, name, input_slots, output_slots)
if not template_step or template_step.name != step_name:
raise NotImplementedError(
f"ParallelStep {self.name} must be initialized with a single node with the same name."
)
self.template_step = template_step
self.template_step.set_parent_step(self)

@property
def config_key(self):
return "parallel"

@property
def num_repeats(self):
return len(self.config)

def validate_step(
self, step_config: LayeredConfigTree, input_data_config: LayeredConfigTree
) -> dict[str, list[str]]:
if not self.config_key in step_config:
return self.template_step.validate_step(step_config, input_data_config)

sub_config = step_config[self.config_key]

if not isinstance(sub_config, list):
return {
f"step {self.name}": [
"Parallel instances must be formatted as a sequence in the pipeline configuration."
]
}

if len(sub_config) == 0:
return {
f"step {self.name}": [
"No parallel instances configured under 'parallel' key."
]
}

errors = defaultdict(dict)
for i, parallel_config in enumerate(sub_config):
parallel_errors = {}
input_data_file = parallel_config.get("input_data_file")
if input_data_file and not input_data_file in input_data_config:
parallel_errors["Input Data Key"] = [
f"Input data file '{input_data_file}' not found in input data configuration."
]
parallel_errors.update(
self.template_step.validate_step(parallel_config, input_data_config)
)
if parallel_errors:
errors[f"step {self.name}"][f"parallel_split_{i+1}"] = parallel_errors
return errors

def set_step_config(self, parent_config: LayeredConfigTree) -> None:
step_config = parent_config[self.name]
if not self.config_key in step_config:
self._config = step_config
self.layer_state = LeafState(self)
else:
self._config = self._get_expanded_config(step_config[self.config_key])
self.step_graph = self._get_step_graph()
self.slot_mappings = self._get_slot_mappings()
self.layer_state = CompositeState(self)
def node_prefix(self):
return "parallel_split"

def _get_step_graph(self) -> StepGraph:
"""Make N copies of the template step that are independent and contain the same edges as the
Expand Down Expand Up @@ -696,13 +678,3 @@ def _get_slot_mappings(self) -> dict[str, list[SlotMapping]]:
for slot in self.output_slots
]
return {"input": input_mappings, "output": output_mappings}

def _get_expanded_config(
self, step_config: LayeredConfigTree
) -> dict[str, LayeredConfigTree]:
"""Get the dictionary for the parallel graph based on the sequence
of sub-yamls."""
expanded_step_config = {}
for i, sub_config in enumerate(step_config):
expanded_step_config[f"{self.name}_parallel_split_{i+1}"] = sub_config
return LayeredConfigTree(expanded_step_config)
4 changes: 2 additions & 2 deletions tests/unit/test_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_batch_validation():
{
PIPELINE_ERRORS_KEY: {
"development": {
"step step_3": ["No loops configured under iterate key."],
"step step_3": ["No loop instances configured under iterate key."],
},
},
},
Expand All @@ -125,7 +125,7 @@ def test_batch_validation():
PIPELINE_ERRORS_KEY: {
"development": {
"step step_3": [
"Loops must be formatted as a sequence in the pipeline configuration."
"Loop instances must be formatted as a sequence in the pipeline configuration."
],
},
},
Expand Down

0 comments on commit 7fe2577

Please sign in to comment.