Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

F/flow node #14

Merged
merged 2 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions retrack/engine/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def _set_indexes_by_name_map(self):
self._indexes_by_name_map[node_name].append(node_id)

def get_by_name(self, name: str) -> typing.List[nodes.BaseNode]:
name = name.lower()
return [self.get_by_id(id_) for id_ in self.indexes_by_name_map.get(name, [])]

@property
Expand Down
22 changes: 20 additions & 2 deletions retrack/engine/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
class Runner:
def __init__(self, parser: Parser):
self._parser = parser
self._internal_runners = {}
self.reset()
self._set_constants()
self._set_input_columns()
self._request_manager = RequestManager(self._parser.get_by_kind(NodeKind.INPUT))
self._set_internal_runners()

@classmethod
def from_json(cls, data: typing.Union[str, dict], **kwargs):
Expand Down Expand Up @@ -61,6 +63,19 @@ def _set_constants(self):
for output_connector_name, _ in node.outputs:
self._constants[f"{node.id}@{output_connector_name}"] = node.data.value

def _set_internal_runners(self):
for node_id in self.parser.indexes_by_name_map.get(
constants.FLOW_NODE_NAME, []
):
try:
self._internal_runners[node_id] = Runner.from_json(
self.parser.get_by_id(node_id).data.parsed_value()
)
except Exception as e:
raise Exception(
f"Error setting internal runner for node {node_id}"
) from e

@property
def input_columns(self) -> dict:
return self._input_columns
Expand Down Expand Up @@ -108,7 +123,7 @@ def _create_initial_state_from_payload(
return state_df

def __get_input_params(
self, node_dict: dict, current_node_filter: pd.Series
self, node_id: str, node_dict: dict, current_node_filter: pd.Series
) -> dict:
input_params = {}

Expand All @@ -121,6 +136,9 @@ def __get_input_params(
f"{connection['node']}@{connection['output']}", current_node_filter
)

if node_id in self._internal_runners:
input_params["runner"] = self._internal_runners[node_id]

return input_params

def __set_state_data(
Expand Down Expand Up @@ -151,7 +169,7 @@ def __run_node(self, node_id: str):
return

input_params = self.__get_input_params(
node.dict(by_alias=True), current_node_filter
node_id, node.dict(by_alias=True), current_node_filter
)
output = node.run(**input_params)

Expand Down
3 changes: 2 additions & 1 deletion retrack/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from retrack.nodes.inputs import Input
from retrack.nodes.logic import And, Not, Or
from retrack.nodes.match import If
from retrack.nodes.math import AbsoluteValue, Math
from retrack.nodes.math import AbsoluteValue, Math, Round
from retrack.nodes.outputs import Output
from retrack.nodes.start import Start
from retrack.nodes.startswith import StartsWith
Expand Down Expand Up @@ -40,6 +40,7 @@ def register(name: str, node: BaseNode) -> None:
register("Or", Or)
register("Not", Not)
register("Math", Math)
register("Round", Round)
register("AbsoluteValue", AbsoluteValue)
register("StartsWith", StartsWith)
register("EndsWith", EndsWith)
Expand Down
2 changes: 2 additions & 0 deletions retrack/nodes/dynamic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from retrack.nodes.dynamic.base import BaseDynamicNode
from retrack.nodes.dynamic.csv_table import csv_table_factory
from retrack.nodes.dynamic.flow import flow_factory
from retrack.utils.registry import Registry

_registry = Registry()
Expand All @@ -21,5 +22,6 @@ def register(


register("CSVTableV0", csv_table_factory)
register("FlowV0", flow_factory)

__all__ = ["registry", "register", "BaseDynamicNode"]
4 changes: 2 additions & 2 deletions retrack/nodes/dynamic/csv_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def csv_table_factory(

BaseCSVTableV0Model = BaseDynamicNode.with_fields("CSVTableV0", **models)

class CSVTableV0Model(BaseCSVTableV0Model):
class CSVTableV0(BaseCSVTableV0Model):
def run(self, **kwargs) -> typing.Dict[str, typing.Any]:
csv_df = self.data.df()

Expand Down Expand Up @@ -78,4 +78,4 @@ def run(self, **kwargs) -> typing.Dict[str, typing.Any]:

return {"output_value": response_df[self.data.target]}

return CSVTableV0Model
return CSVTableV0
58 changes: 58 additions & 0 deletions retrack/nodes/dynamic/flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import typing

import json

import pandas as pd
import pydantic

from retrack.nodes.base import InputConnectionModel, OutputConnectionModel
from retrack.nodes.dynamic.base import BaseDynamicIOModel, BaseDynamicNode


class FlowV0MetadataModel(pydantic.BaseModel):
value: str
default: typing.Optional[str] = None

def parsed_value(self) -> typing.Dict[str, typing.Any]:
return json.loads(self.value)


class FlowV0OutputsModel(pydantic.BaseModel):
output_value: OutputConnectionModel


def flow_factory(
inputs: typing.Dict[str, typing.Any], **kwargs
) -> typing.Type[BaseDynamicNode]:
input_fields = {}

for name in inputs.keys():
input_fields[name] = BaseDynamicNode.create_sub_field(InputConnectionModel)

inputs_model = BaseDynamicIOModel.with_fields("FlowV0InputsModel", **input_fields)

models = {
"inputs": BaseDynamicNode.create_sub_field(inputs_model),
"outputs": BaseDynamicNode.create_sub_field(FlowV0OutputsModel),
"data": BaseDynamicNode.create_sub_field(FlowV0MetadataModel),
}

BaseFlowV0Model = BaseDynamicNode.with_fields("FlowV0", **models)

class FlowV0(BaseFlowV0Model):
def run(self, **kwargs) -> typing.Dict[str, typing.Any]:
runner = kwargs.get("runner", None)
if runner is None:
raise ValueError("Missing runner")

inputs_in_kwargs = {}

for name, value in kwargs.items():
if name.startswith("input_"):
inputs_in_kwargs[name[6:]] = value

response = runner.execute(pd.DataFrame(inputs_in_kwargs))

return {"output_value": response["output"].values}

return FlowV0
16 changes: 16 additions & 0 deletions retrack/nodes/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,19 @@ def run(
input_value: pd.Series,
) -> typing.Dict[str, pd.Series]:
return {"output_value": input_value.astype(float).abs()}


###############################################################
# Round Node
###############################################################


class Round(BaseNode):
inputs: AbsoluteValueInputsModel
outputs: MathOutputsModel

def run(
self,
input_value: pd.Series,
) -> typing.Dict[str, pd.Series]:
return {"output_value": input_value.astype(float).round(0).astype(int)}
1 change: 1 addition & 0 deletions retrack/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
NULL_SUFFIX = "_void"
FILTER_SUFFIX = "_filter"
INPUT_OUTPUT_VALUE_CONNECTOR_NAME = "output_value"
FLOW_NODE_NAME = "flowv0"
Loading
Loading