Skip to content

Commit

Permalink
fix: input_port_types and other helper functions on pydantic schema (#…
Browse files Browse the repository at this point in the history
…958)

Fixes several schema issue. Closes #986
  • Loading branch information
mark-koch committed May 2, 2024
1 parent 954b2cb commit 8651839
Show file tree
Hide file tree
Showing 6 changed files with 1,441 additions and 245 deletions.
70 changes: 25 additions & 45 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import sys
from abc import ABC
from typing import Any, Literal, cast
from typing import Any, Literal

from pydantic import Field, RootModel

Expand All @@ -28,7 +28,7 @@ class BaseOp(ABC, ConfiguredBaseModel):

# Parent node index of node the op belongs to, used only at serialization time
parent: NodeID
input_extensions: ExtensionSet = Field(default_factory=ExtensionSet)
input_extensions: ExtensionSet | None = Field(default=None)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
"""Hook to insert type information from the input and output ports into the
Expand Down Expand Up @@ -59,29 +59,15 @@ class FuncDefn(BaseOp):
op: Literal["FuncDefn"] = "FuncDefn"

name: str
signature: PolyFuncType = Field(default_factory=PolyFuncType.empty)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
assert len(in_types) == 0
assert len(out_types) == 1
out = out_types[0]
assert isinstance(out, PolyFuncType)
self.signature = out # TODO: Extensions
signature: PolyFuncType


class FuncDecl(BaseOp):
"""External function declaration, linked at runtime."""

op: Literal["FuncDecl"] = "FuncDecl"
name: str
signature: PolyFuncType = Field(default_factory=PolyFuncType.empty)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
assert len(in_types) == 0
assert len(out_types) == 1
out = out_types[0]
assert isinstance(out, PolyFuncType)
self.signature = out
signature: PolyFuncType


CustomConst = Any # TODO
Expand Down Expand Up @@ -186,13 +172,13 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:

def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None:
self.inputs = inputs
pred = outputs[0]
assert isinstance(pred, tys.UnitSum | tys.GeneralSum)
if isinstance(pred, tys.UnitSum):
self.sum_rows = [[] for _ in range(cast(tys.UnitSum, pred).size)]
pred = outputs[0].root
assert isinstance(pred, tys.TaggedSumType)
if isinstance(pred.st.root, tys.UnitSum):
self.sum_rows = [[] for _ in range(pred.st.root.size)]
else:
self.sum_rows = []
for variant in pred.rows:
for variant in pred.st.root.rows:
self.sum_rows.append(variant)
self.other_outputs = outputs[1:]

Expand Down Expand Up @@ -266,16 +252,9 @@ class Call(DataflowOp):
"""

op: Literal["Call"] = "Call"
func_sig: PolyFuncType = Field(default_factory=FunctionType.empty)
type_args: list[tys.TypeArg] = Field(default_factory=list)
instantiation: FunctionType = Field(default_factory=FunctionType.empty)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
fun_ty = in_types[-1]
assert isinstance(fun_ty, PolyFuncType)
poly_func = cast(PolyFuncType, fun_ty)
assert len(poly_func.params) == 0
self.signature = poly_func.body
func_sig: PolyFuncType
type_args: list[tys.TypeArg]
instantiation: FunctionType

class Config:
# Needed to avoid random '\n's in the pydantic description
Expand All @@ -292,19 +271,18 @@ class Config:
class CallIndirect(DataflowOp):
"""Call a function indirectly.
Like call, but the first input is a standard dataflow graph type."""
Like call, but the first input is a standard dataflow graph type.
"""

op: Literal["CallIndirect"] = "CallIndirect"
signature: FunctionType = Field(default_factory=FunctionType.empty)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
fun_ty = in_types[0]
assert isinstance(fun_ty, PolyFuncType)
poly_func = cast(PolyFuncType, fun_ty)
assert len(poly_func.params) == 0
assert len(poly_func.body.input) == len(in_types) - 1
assert len(poly_func.body.output) == len(out_types)
self.signature = poly_func.body
fun_ty = in_types[0].root
assert isinstance(fun_ty, FunctionType)
assert len(fun_ty.input) == len(in_types) - 1
assert len(fun_ty.output) == len(out_types)
self.signature = fun_ty


class LoadConstant(DataflowOp):
Expand Down Expand Up @@ -359,12 +337,14 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
# First port is a predicate, i.e. a sum of tuple types. We need to unpack
# those into a list of type rows
pred = in_types[0]
if isinstance(pred, tys.UnitSum):
self.sum_rows = [[] for _ in range(cast(tys.UnitSum, pred).size)]
assert isinstance(pred.root, tys.TaggedSumType)
sum = pred.root.st.root
if isinstance(sum, tys.UnitSum):
self.sum_rows = [[] for _ in range(sum.size)]
else:
assert isinstance(pred, tys.GeneralSum)
assert isinstance(sum, tys.GeneralSum)
self.sum_rows = []
for ty in pred.rows:
for ty in sum.rows:
self.sum_rows.append(ty)
self.other_inputs = list(in_types[1:])
self.outputs = list(out_types)
Expand Down
12 changes: 4 additions & 8 deletions hugr-py/src/hugr/serialization/tys.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import sys
from enum import Enum
from typing import Annotated, Any, Literal, Optional, Union, Mapping
from typing import Annotated, Any, Literal, Union, Mapping

from pydantic import (
BaseModel,
Expand All @@ -28,6 +28,7 @@ def _json_custom_error_validator(
Used to define named recursive alias types.
"""
return handler(value)
try:
return handler(value)
except ValidationError as err:
Expand All @@ -38,6 +39,7 @@ def _json_custom_error_validator(


ExtensionId = str
ExtensionSet = list[ExtensionId]

default_model_config = ConfigDict()

Expand All @@ -50,12 +52,6 @@ def set_model_config(cls, config: ConfigDict):
cls.model_config = config


class ExtensionSet(RootModel):
"""A set of extensions ids."""

root: Optional[list[ExtensionId]] = Field(default=None)


# --------------------------------------------
# --------------- TypeParam ------------------
# --------------------------------------------
Expand Down Expand Up @@ -219,7 +215,7 @@ class FunctionType(ConfiguredBaseModel):

@classmethod
def empty(cls) -> "FunctionType":
return FunctionType(input=[], output=[], extension_reqs=ExtensionSet([]))
return FunctionType(input=[], output=[], extension_reqs=[])

class Config:
# Needed to avoid random '\n's in the pydantic description
Expand Down
Loading

0 comments on commit 8651839

Please sign in to comment.