diff --git a/retrack/engine/parser.py b/retrack/engine/parser.py index aa6a85c..42693b0 100644 --- a/retrack/engine/parser.py +++ b/retrack/engine/parser.py @@ -112,6 +112,7 @@ def _set_indexes_by_name_map(self): self._indexes_by_name_map[node_name].append(node_id) def get_by_name(self, name: str) -> typing.List[nodes.BaseNode]: + name = name.lower() return [self.get_by_id(id_) for id_ in self.indexes_by_name_map.get(name, [])] @property diff --git a/retrack/engine/runner.py b/retrack/engine/runner.py index 55aeb6f..74d6883 100644 --- a/retrack/engine/runner.py +++ b/retrack/engine/runner.py @@ -15,10 +15,12 @@ class Runner: def __init__(self, parser: Parser): self._parser = parser + self._internal_runners = {} self.reset() self._set_constants() self._set_input_columns() self._request_manager = RequestManager(self._parser.get_by_kind(NodeKind.INPUT)) + self._set_internal_runners() @classmethod def from_json(cls, data: typing.Union[str, dict], **kwargs): @@ -61,6 +63,19 @@ def _set_constants(self): for output_connector_name, _ in node.outputs: self._constants[f"{node.id}@{output_connector_name}"] = node.data.value + def _set_internal_runners(self): + for node_id in self.parser.indexes_by_name_map.get( + constants.FLOW_NODE_NAME, [] + ): + try: + self._internal_runners[node_id] = Runner.from_json( + self.parser.get_by_id(node_id).data.parsed_value() + ) + except Exception as e: + raise Exception( + f"Error setting internal runner for node {node_id}" + ) from e + @property def input_columns(self) -> dict: return self._input_columns @@ -108,7 +123,7 @@ def _create_initial_state_from_payload( return state_df def __get_input_params( - self, node_dict: dict, current_node_filter: pd.Series + self, node_id: str, node_dict: dict, current_node_filter: pd.Series ) -> dict: input_params = {} @@ -121,6 +136,9 @@ def __get_input_params( f"{connection['node']}@{connection['output']}", current_node_filter ) + if node_id in self._internal_runners: + input_params["runner"] = self._internal_runners[node_id] + return input_params def __set_state_data( @@ -151,7 +169,7 @@ def __run_node(self, node_id: str): return input_params = self.__get_input_params( - node.dict(by_alias=True), current_node_filter + node_id, node.dict(by_alias=True), current_node_filter ) output = node.run(**input_params) diff --git a/retrack/nodes/__init__.py b/retrack/nodes/__init__.py index e8aa0d9..4ed404e 100644 --- a/retrack/nodes/__init__.py +++ b/retrack/nodes/__init__.py @@ -10,7 +10,7 @@ from retrack.nodes.inputs import Input from retrack.nodes.logic import And, Not, Or from retrack.nodes.match import If -from retrack.nodes.math import AbsoluteValue, Math +from retrack.nodes.math import AbsoluteValue, Math, Round from retrack.nodes.outputs import Output from retrack.nodes.start import Start from retrack.nodes.startswith import StartsWith @@ -40,6 +40,7 @@ def register(name: str, node: BaseNode) -> None: register("Or", Or) register("Not", Not) register("Math", Math) +register("Round", Round) register("AbsoluteValue", AbsoluteValue) register("StartsWith", StartsWith) register("EndsWith", EndsWith) diff --git a/retrack/nodes/dynamic/__init__.py b/retrack/nodes/dynamic/__init__.py index a1caadd..5ae3890 100644 --- a/retrack/nodes/dynamic/__init__.py +++ b/retrack/nodes/dynamic/__init__.py @@ -2,6 +2,7 @@ from retrack.nodes.dynamic.base import BaseDynamicNode from retrack.nodes.dynamic.csv_table import csv_table_factory +from retrack.nodes.dynamic.flow import flow_factory from retrack.utils.registry import Registry _registry = Registry() @@ -21,5 +22,6 @@ def register( register("CSVTableV0", csv_table_factory) +register("FlowV0", flow_factory) __all__ = ["registry", "register", "BaseDynamicNode"] diff --git a/retrack/nodes/dynamic/csv_table.py b/retrack/nodes/dynamic/csv_table.py index bc786e8..b705258 100644 --- a/retrack/nodes/dynamic/csv_table.py +++ b/retrack/nodes/dynamic/csv_table.py @@ -44,7 +44,7 @@ def csv_table_factory( BaseCSVTableV0Model = BaseDynamicNode.with_fields("CSVTableV0", **models) - class CSVTableV0Model(BaseCSVTableV0Model): + class CSVTableV0(BaseCSVTableV0Model): def run(self, **kwargs) -> typing.Dict[str, typing.Any]: csv_df = self.data.df() @@ -78,4 +78,4 @@ def run(self, **kwargs) -> typing.Dict[str, typing.Any]: return {"output_value": response_df[self.data.target]} - return CSVTableV0Model + return CSVTableV0 diff --git a/retrack/nodes/dynamic/flow.py b/retrack/nodes/dynamic/flow.py new file mode 100644 index 0000000..a8b8b00 --- /dev/null +++ b/retrack/nodes/dynamic/flow.py @@ -0,0 +1,58 @@ +import typing + +import json + +import pandas as pd +import pydantic + +from retrack.nodes.base import InputConnectionModel, OutputConnectionModel +from retrack.nodes.dynamic.base import BaseDynamicIOModel, BaseDynamicNode + + +class FlowV0MetadataModel(pydantic.BaseModel): + value: str + default: typing.Optional[str] = None + + def parsed_value(self) -> typing.Dict[str, typing.Any]: + return json.loads(self.value) + + +class FlowV0OutputsModel(pydantic.BaseModel): + output_value: OutputConnectionModel + + +def flow_factory( + inputs: typing.Dict[str, typing.Any], **kwargs +) -> typing.Type[BaseDynamicNode]: + input_fields = {} + + for name in inputs.keys(): + input_fields[name] = BaseDynamicNode.create_sub_field(InputConnectionModel) + + inputs_model = BaseDynamicIOModel.with_fields("FlowV0InputsModel", **input_fields) + + models = { + "inputs": BaseDynamicNode.create_sub_field(inputs_model), + "outputs": BaseDynamicNode.create_sub_field(FlowV0OutputsModel), + "data": BaseDynamicNode.create_sub_field(FlowV0MetadataModel), + } + + BaseFlowV0Model = BaseDynamicNode.with_fields("FlowV0", **models) + + class FlowV0(BaseFlowV0Model): + def run(self, **kwargs) -> typing.Dict[str, typing.Any]: + runner = kwargs.get("runner", None) + if runner is None: + raise ValueError("Missing runner") + + inputs_in_kwargs = {} + + for name, value in kwargs.items(): + if name.startswith("input_"): + inputs_in_kwargs[name[6:]] = value + + response = runner.execute(pd.DataFrame(inputs_in_kwargs)) + + return {"output_value": response["output"].values} + + return FlowV0 diff --git a/retrack/nodes/math.py b/retrack/nodes/math.py index 662fdde..049a59d 100644 --- a/retrack/nodes/math.py +++ b/retrack/nodes/math.py @@ -94,3 +94,19 @@ def run( input_value: pd.Series, ) -> typing.Dict[str, pd.Series]: return {"output_value": input_value.astype(float).abs()} + + +############################################################### +# Round Node +############################################################### + + +class Round(BaseNode): + inputs: AbsoluteValueInputsModel + outputs: MathOutputsModel + + def run( + self, + input_value: pd.Series, + ) -> typing.Dict[str, pd.Series]: + return {"output_value": input_value.astype(float).round(0).astype(int)} diff --git a/retrack/utils/constants.py b/retrack/utils/constants.py index 09d0a38..65d5663 100644 --- a/retrack/utils/constants.py +++ b/retrack/utils/constants.py @@ -3,3 +3,4 @@ NULL_SUFFIX = "_void" FILTER_SUFFIX = "_filter" INPUT_OUTPUT_VALUE_CONNECTOR_NAME = "output_value" +FLOW_NODE_NAME = "flowv0" diff --git a/tests/resources/round-node.json b/tests/resources/round-node.json new file mode 100644 index 0000000..3641ef7 --- /dev/null +++ b/tests/resources/round-node.json @@ -0,0 +1,199 @@ +{ + "id": "demo@0.1.0", + "nodes": { + "0": { + "id": 0, + "data": {}, + "inputs": {}, + "outputs": { + "output_up_void": { + "connections": [ + { + "node": 2, + "input": "input_void", + "data": {} + } + ] + }, + "output_down_void": { + "connections": [ + { + "node": 3, + "input": "input_void", + "data": {} + } + ] + } + }, + "position": [ + 0, + 0 + ], + "name": "Start" + }, + "2": { + "id": 2, + "data": { + "name": "var_a", + "default": null + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 0, + "output": "output_up_void", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 4, + "input": "input_value_0", + "data": {} + } + ] + } + }, + "position": [ + 364.0234375, + -195.8359375 + ], + "name": "Input" + }, + "3": { + "id": 3, + "data": { + "name": "var_b", + "default": null + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 0, + "output": "output_down_void", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 4, + "input": "input_value_1", + "data": {} + } + ] + } + }, + "position": [ + 363.8376785973601, + 108.76168656142642 + ], + "name": "Input" + }, + "4": { + "id": 4, + "data": { + "operator": "*" + }, + "inputs": { + "input_value_0": { + "connections": [ + { + "node": 2, + "output": "output_value", + "data": {} + } + ] + }, + "input_value_1": { + "connections": [ + { + "node": 3, + "output": "output_value", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 6, + "input": "input_value", + "data": {} + } + ] + } + }, + "position": [ + 717.8109142505134, + -74.6040121470534 + ], + "name": "Math" + }, + "6": { + "id": 6, + "data": {}, + "inputs": { + "input_value": { + "connections": [ + { + "node": 4, + "output": "output_value", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 7, + "input": "input_value", + "data": {} + } + ] + } + }, + "position": [ + 1005.7455832447746, + -52.25458299824986 + ], + "name": "Round" + }, + "7": { + "id": 7, + "data": { + "message": null + }, + "inputs": { + "input_value": { + "connections": [ + { + "node": 6, + "output": "output_value", + "data": {} + } + ] + } + }, + "outputs": {}, + "position": [ + 1279.293661227661, + -54.172931549720175 + ], + "name": "Output" + } + } +} \ No newline at end of file diff --git a/tests/resources/rule-of-rules.json b/tests/resources/rule-of-rules.json new file mode 100644 index 0000000..7acada2 --- /dev/null +++ b/tests/resources/rule-of-rules.json @@ -0,0 +1,169 @@ +{ + "id": "demo@0.1.0", + "nodes": { + "0": { + "id": 0, + "data": {}, + "inputs": {}, + "outputs": { + "output_up_void": { + "connections": [ + { + "node": 3, + "input": "input_void", + "data": {} + } + ] + }, + "output_down_void": { + "connections": [ + { + "node": 4, + "input": "input_void", + "data": {} + } + ] + } + }, + "position": [ + 0, + 0 + ], + "name": "Start" + }, + "2": { + "id": 2, + "data": { + "value": "{\n\t\"id\": \"demo@0.1.0\",\n\t\"nodes\": {\n\t\t\"0\": {\n\t\t\t\"id\": 0,\n\t\t\t\"data\": {},\n\t\t\t\"inputs\": {},\n\t\t\t\"outputs\": {\n\t\t\t\t\"output_up_void\": {\n\t\t\t\t\t\"connections\": [\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\"node\": 2,\n\t\t\t\t\t\t\t\"input\": \"input_void\",\n\t\t\t\t\t\t\t\"data\": {}\n\t\t\t\t\t\t}\n\t\t\t\t\t]\n\t\t\t\t},\n\t\t\t\t\"output_down_void\": {\n\t\t\t\t\t\"connections\": [\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\"node\": 3,\n\t\t\t\t\t\t\t\"input\": \"input_void\",\n\t\t\t\t\t\t\t\"data\": {}\n\t\t\t\t\t\t}\n\t\t\t\t\t]\n\t\t\t\t}\n\t\t\t},\n\t\t\t\"position\": [\n\t\t\t\t0,\n\t\t\t\t0\n\t\t\t],\n\t\t\t\"name\": \"Start\"\n\t\t},\n\t\t\"2\": {\n\t\t\t\"id\": 2,\n\t\t\t\"data\": {\n\t\t\t\t\"name\": \"var_a\",\n\t\t\t\t\"default\": null\n\t\t\t},\n\t\t\t\"inputs\": {\n\t\t\t\t\"input_void\": {\n\t\t\t\t\t\"connections\": [\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\"node\": 0,\n\t\t\t\t\t\t\t\"output\": \"output_up_void\",\n\t\t\t\t\t\t\t\"data\": {}\n\t\t\t\t\t\t}\n\t\t\t\t\t]\n\t\t\t\t}\n\t\t\t},\n\t\t\t\"outputs\": {\n\t\t\t\t\"output_value\": {\n\t\t\t\t\t\"connections\": [\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\"node\": 4,\n\t\t\t\t\t\t\t\"input\": \"input_value_0\",\n\t\t\t\t\t\t\t\"data\": {}\n\t\t\t\t\t\t}\n\t\t\t\t\t]\n\t\t\t\t}\n\t\t\t},\n\t\t\t\"position\": [\n\t\t\t\t364.0234375,\n\t\t\t\t-195.8359375\n\t\t\t],\n\t\t\t\"name\": \"Input\"\n\t\t},\n\t\t\"3\": {\n\t\t\t\"id\": 3,\n\t\t\t\"data\": {\n\t\t\t\t\"name\": \"var_b\",\n\t\t\t\t\"default\": null\n\t\t\t},\n\t\t\t\"inputs\": {\n\t\t\t\t\"input_void\": {\n\t\t\t\t\t\"connections\": [\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\"node\": 0,\n\t\t\t\t\t\t\t\"output\": \"output_down_void\",\n\t\t\t\t\t\t\t\"data\": {}\n\t\t\t\t\t\t}\n\t\t\t\t\t]\n\t\t\t\t}\n\t\t\t},\n\t\t\t\"outputs\": {\n\t\t\t\t\"output_value\": {\n\t\t\t\t\t\"connections\": [\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\"node\": 4,\n\t\t\t\t\t\t\t\"input\": \"input_value_1\",\n\t\t\t\t\t\t\t\"data\": {}\n\t\t\t\t\t\t}\n\t\t\t\t\t]\n\t\t\t\t}\n\t\t\t},\n\t\t\t\"position\": [\n\t\t\t\t363.8376785973601,\n\t\t\t\t108.76168656142642\n\t\t\t],\n\t\t\t\"name\": \"Input\"\n\t\t},\n\t\t\"4\": {\n\t\t\t\"id\": 4,\n\t\t\t\"data\": {\n\t\t\t\t\"operator\": \"*\"\n\t\t\t},\n\t\t\t\"inputs\": {\n\t\t\t\t\"input_value_0\": {\n\t\t\t\t\t\"connections\": [\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\"node\": 2,\n\t\t\t\t\t\t\t\"output\": \"output_value\",\n\t\t\t\t\t\t\t\"data\": {}\n\t\t\t\t\t\t}\n\t\t\t\t\t]\n\t\t\t\t},\n\t\t\t\t\"input_value_1\": {\n\t\t\t\t\t\"connections\": [\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\"node\": 3,\n\t\t\t\t\t\t\t\"output\": \"output_value\",\n\t\t\t\t\t\t\t\"data\": {}\n\t\t\t\t\t\t}\n\t\t\t\t\t]\n\t\t\t\t}\n\t\t\t},\n\t\t\t\"outputs\": {\n\t\t\t\t\"output_value\": {\n\t\t\t\t\t\"connections\": [\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\"node\": 6,\n\t\t\t\t\t\t\t\"input\": \"input_value\",\n\t\t\t\t\t\t\t\"data\": {}\n\t\t\t\t\t\t}\n\t\t\t\t\t]\n\t\t\t\t}\n\t\t\t},\n\t\t\t\"position\": [\n\t\t\t\t717.8109142505134,\n\t\t\t\t-74.6040121470534\n\t\t\t],\n\t\t\t\"name\": \"Math\"\n\t\t},\n\t\t\"6\": {\n\t\t\t\"id\": 6,\n\t\t\t\"data\": {},\n\t\t\t\"inputs\": {\n\t\t\t\t\"input_value\": {\n\t\t\t\t\t\"connections\": [\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\"node\": 4,\n\t\t\t\t\t\t\t\"output\": \"output_value\",\n\t\t\t\t\t\t\t\"data\": {}\n\t\t\t\t\t\t}\n\t\t\t\t\t]\n\t\t\t\t}\n\t\t\t},\n\t\t\t\"outputs\": {\n\t\t\t\t\"output_value\": {\n\t\t\t\t\t\"connections\": [\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\"node\": 7,\n\t\t\t\t\t\t\t\"input\": \"input_value\",\n\t\t\t\t\t\t\t\"data\": {}\n\t\t\t\t\t\t}\n\t\t\t\t\t]\n\t\t\t\t}\n\t\t\t},\n\t\t\t\"position\": [\n\t\t\t\t1005.7455832447746,\n\t\t\t\t-52.25458299824986\n\t\t\t],\n\t\t\t\"name\": \"Round\"\n\t\t},\n\t\t\"7\": {\n\t\t\t\"id\": 7,\n\t\t\t\"data\": {\n\t\t\t\t\"message\": null\n\t\t\t},\n\t\t\t\"inputs\": {\n\t\t\t\t\"input_value\": {\n\t\t\t\t\t\"connections\": [\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\"node\": 6,\n\t\t\t\t\t\t\t\"output\": \"output_value\",\n\t\t\t\t\t\t\t\"data\": {}\n\t\t\t\t\t\t}\n\t\t\t\t\t]\n\t\t\t\t}\n\t\t\t},\n\t\t\t\"outputs\": {},\n\t\t\t\"position\": [\n\t\t\t\t1279.293661227661,\n\t\t\t\t-54.172931549720175\n\t\t\t],\n\t\t\t\"name\": \"Output\"\n\t\t}\n\t}\n}", + "default": null + }, + "inputs": { + "input_var_a": { + "connections": [ + { + "node": 3, + "output": "output_value", + "data": {} + } + ] + }, + "input_var_b": { + "connections": [ + { + "node": 4, + "output": "output_value", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 5, + "input": "input_value", + "data": {} + } + ] + } + }, + "position": [ + 588.546875, + -135.41796875 + ], + "name": "FlowV0" + }, + "3": { + "id": 3, + "data": { + "name": "example_a", + "default": "" + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 0, + "output": "output_up_void", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 2, + "input": "input_var_a", + "data": {} + } + ] + } + }, + "position": [ + 292.11328125, + -192.6953125 + ], + "name": "Input" + }, + "4": { + "id": 4, + "data": { + "name": "var_b", + "default": null + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 0, + "output": "output_down_void", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 2, + "input": "input_var_b", + "data": {} + } + ] + } + }, + "position": [ + 308.3056102761418, + 98.85067813963023 + ], + "name": "Input" + }, + "5": { + "id": 5, + "data": { + "message": null + }, + "inputs": { + "input_value": { + "connections": [ + { + "node": 2, + "output": "output_value", + "data": {} + } + ] + } + }, + "outputs": {}, + "position": [ + 853.1862787404114, + -109.19324694397277 + ], + "name": "Output" + } + } +} \ No newline at end of file diff --git a/tests/test_engine/test_runner.py b/tests/test_engine/test_runner.py index 32376ec..d41bc64 100644 --- a/tests/test_engine/test_runner.py +++ b/tests/test_engine/test_runner.py @@ -97,6 +97,29 @@ def test_flows(filename, in_values, expected_out_values): {"message": "valid age", "output": True}, ], ), + ( + "round-node", + [{"var_a": 1.1, "var_b": 1.5}, {"var_a": 3.6, "var_b": 2.1}], + [ + {"output": 2, "message": None}, + {"output": 8, "message": None}, + ], + ), + ( + "rule-of-rules", + [ + {"example_a": 1, "var_b": 2}, + {"example_a": 3, "var_b": 2}, + {"example_a": 4, "var_b": 2}, + {"example_a": 5, "var_b": 5}, + ], + [ + {"output": 2, "message": None}, + {"output": 6, "message": None}, + {"output": 8, "message": None}, + {"output": 25, "message": None}, + ], + ), ], ) def test_create_from_json(filename, in_values, expected_out_values): diff --git a/tests/test_nodes/test_math.py b/tests/test_nodes/test_math.py index a916439..7ae9fd6 100644 --- a/tests/test_nodes/test_math.py +++ b/tests/test_nodes/test_math.py @@ -1,7 +1,7 @@ import pandas as pd import pytest -from retrack.nodes.math import AbsoluteValue, Math, MathOperator +from retrack.nodes.math import AbsoluteValue, Math, MathOperator, Round @pytest.fixture @@ -75,3 +75,9 @@ def test_absolute_value_node_run(absolute_value_input_data): absolute_value_node = AbsoluteValue(**absolute_value_input_data) output = absolute_value_node.run(pd.Series(["-1", "1", "0", "-2"])) assert (output["output_value"] == pd.Series([1, 1, 0, 2])).all() + + +def test_round_node_run(absolute_value_input_data): + round_node = Round(**absolute_value_input_data) + output = round_node.run(pd.Series(["-1.5", "1.5", "0", "-2.5"])) + assert (output["output_value"] == pd.Series([-2, 2, 0, -2])).all()