Skip to content

Commit

Permalink
feat!: Update serialisation schema, implement CustomConst serialisa…
Browse files Browse the repository at this point in the history
…tion (#1005)

We fix `model_rebuild` in `tys.py` to update the `model_config` rather
than overwrite it. This prevents our config, i.e.
`json_scheme_extra.required` from being removed during a
`model_rebuild`.

We remove most `json_scheme_extra.required` from the schema, using these
only for `RootModel`s. This allows us to remove `TaggedSumType`, as well
as alleviating us from the need of introducing `TaggedOpaqueType`.

The serialisation schema is updated, and is `proptest`ed in #981.
Reviewers should verify that the `serde` annotations, the pydantic
schema definition, and the generated schemas exactly match that branch.

BREAKING CHANGE: 
* Serialization schema
* `Const::const_type` and `Value::const_type` are renamed to
`Const::get_type` and `Value::get_type`. These now match several other
`get_type` functions.

---------

Co-authored-by: Craig Roy <croyzor@users.noreply.github.com>
  • Loading branch information
doug-q and croyzor committed May 9, 2024
1 parent fae2993 commit c45e6fc
Show file tree
Hide file tree
Showing 25 changed files with 957 additions and 448 deletions.
71 changes: 26 additions & 45 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,52 +70,41 @@ class FuncDecl(BaseOp):
signature: PolyFuncType


CustomConst = Any # TODO
class CustomConst(ConfiguredBaseModel):
c: str
v: Any


class ExtensionValue(ConfiguredBaseModel):
"""An extension constant value, that can check it is of a given [CustomType]."""

c: Literal["Extension"] = Field("Extension", title="ValueTag")
e: CustomConst = Field(title="CustomConst")

class Config:
json_schema_extra = {
"required": ["c", "e"],
}
v: Literal["Extension"] = Field("Extension", title="ValueTag")
extensions: ExtensionSet
typ: Type
value: CustomConst


class FunctionValue(ConfiguredBaseModel):
"""A higher-order function value."""

c: Literal["Function"] = Field("Function", title="ValueTag")
v: Literal["Function"] = Field("Function", title="ValueTag")
hugr: Any # TODO

class Config:
json_schema_extra = {
"required": ["c", "hugr"],
}


class TupleValue(ConfiguredBaseModel):
"""A constant tuple value."""

c: Literal["Tuple"] = Field("Tuple", title="ValueTag")
v: Literal["Tuple"] = Field("Tuple", title="ValueTag")
vs: list["Value"]

class Config:
json_schema_extra = {
"required": ["c", "vs"],
}


class SumValue(ConfiguredBaseModel):
"""A Sum variant
For any Sum type where this value meets the type of the variant indicated by the tag
"""

c: Literal["Sum"] = Field("Sum", title="ValueTag")
v: Literal["Sum"] = Field("Sum", title="ValueTag")
tag: int
typ: SumType
vs: list["Value"]
Expand All @@ -127,29 +116,26 @@ class Config:
"A Sum variant For any Sum type where this value meets the type "
"of the variant indicated by the tag."
),
"required": ["c", "tag", "typ", "vs"],
}


class Value(RootModel):
"""A constant Value."""

root: ExtensionValue | FunctionValue | TupleValue | SumValue = Field(
discriminator="c"
discriminator="v"
)

class Config:
json_schema_extra = {"required": ["v"]}


class Const(BaseOp):
"""A Const operation definition."""

op: Literal["Const"] = "Const"
v: Value = Field()

class Config:
json_schema_extra = {
"required": ["op", "parent", "v"],
}


# -----------------------------------------------
# --------------- BasicBlock types ------------------
Expand All @@ -163,7 +149,7 @@ class DataflowBlock(BaseOp):
op: Literal["DataflowBlock"] = "DataflowBlock"
inputs: TypeRow = Field(default_factory=list)
other_outputs: TypeRow = Field(default_factory=list)
sum_rows: list[TypeRow] = Field(default_factory=list)
sum_rows: list[TypeRow]
extension_delta: ExtensionSet = Field(default_factory=list)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
Expand All @@ -173,26 +159,18 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None:
self.inputs = inputs
pred = outputs[0].root
assert isinstance(pred, tys.TaggedSumType)
if isinstance(pred.st.root, tys.UnitSum):
self.sum_rows = [[] for _ in range(pred.st.root.size)]
assert isinstance(pred, tys.SumType)
if isinstance(pred.root, tys.UnitSum):
self.sum_rows = [[] for _ in range(pred.root.size)]
else:
self.sum_rows = []
for variant in pred.st.root.rows:
for variant in pred.root.rows:
self.sum_rows.append(variant)
self.other_outputs = outputs[1:]

class Config:
# Needed to avoid random '\n's in the pydantic description
json_schema_extra = {
"required": [
"parent",
"op",
"inputs",
"other_outputs",
"sum_rows",
"extension_delta",
],
"description": "A CFG basic block node. The signature is that of the internal Dataflow graph.",
}

Expand All @@ -205,9 +183,9 @@ class ExitBlock(BaseOp):
cfg_outputs: TypeRow

class Config:
# Needed to avoid random '\n's in the pydantic description
json_schema_extra = {
"description": "The single exit node of the CFG, has no children, stores the types of the CFG node output."
# Needed to avoid random '\n's in the pydantic description
"description": "The single exit node of the CFG, has no children, stores the types of the CFG node output.",
}


Expand Down Expand Up @@ -334,8 +312,8 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
# First port is a predicate, i.e. a sum of tuple types. We need to unpack
# those into a list of type rows
pred = in_types[0]
assert isinstance(pred.root, tys.TaggedSumType)
sum = pred.root.st.root
assert isinstance(pred.root, tys.SumType)
sum = pred.root.root
if isinstance(sum, tys.UnitSum):
self.sum_rows = [[] for _ in range(sum.size)]
else:
Expand Down Expand Up @@ -513,6 +491,9 @@ class OpType(RootModel):
| AliasDefn
) = Field(discriminator="op")

class Config:
json_schema_extra = {"required": ["parent", "op"]}


# --------------------------------------
# --------------- OpDef ----------------
Expand Down
3 changes: 0 additions & 3 deletions hugr-py/src/hugr/serialization/testing_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,3 @@ def _pydantic_rebuild(cls, config: ConfigDict = ConfigDict(), **kwargs):
my_classes = dict(ops_classes)
my_classes[cls.__name__] = cls
model_rebuild(my_classes, config=config, **kwargs)

class Config:
title = "HugrTesting"
48 changes: 25 additions & 23 deletions hugr-py/src/hugr/serialization/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ class ConfiguredBaseModel(BaseModel):
model_config = default_model_config

@classmethod
def set_model_config(cls, config: ConfigDict):
cls.model_config = config
def update_model_config(cls, config: ConfigDict):
cls.model_config.update(config)


# --------------------------------------------
Expand Down Expand Up @@ -99,6 +99,9 @@ class TypeParam(RootModel):
WrapValidator(_json_custom_error_validator),
] = Field(discriminator="tp")

class Config:
json_schema_extra = {"required": ["tp"]}


# ------------------------------------------
# --------------- TypeArg ------------------
Expand Down Expand Up @@ -150,6 +153,9 @@ class TypeArg(RootModel):
WrapValidator(_json_custom_error_validator),
] = Field(discriminator="tya")

class Config:
json_schema_extra = {"required": ["tya"]}


# --------------------------------------------
# --------------- Container ------------------
Expand All @@ -170,24 +176,29 @@ class Array(MultiContainer):
class UnitSum(ConfiguredBaseModel):
"""Simple sum type where all variants are empty tuples."""

t: Literal["Sum"] = "Sum"
s: Literal["Unit"] = "Unit"
size: int


class GeneralSum(ConfiguredBaseModel):
"""General sum type that explicitly stores the types of the variants."""

t: Literal["Sum"] = "Sum"
s: Literal["General"] = "General"
rows: list["TypeRow"]


class SumType(RootModel):
root: Union[UnitSum, GeneralSum] = Field(discriminator="s")
root: Annotated[Union[UnitSum, GeneralSum], Field(discriminator="s")]

# This seems to be required for nested discriminated unions to work
@property
def t(self) -> str:
return self.root.t

class TaggedSumType(ConfiguredBaseModel):
t: Literal["Sum"] = "Sum"
st: SumType
class Config:
json_schema_extra = {"required": ["s"]}


# ----------------------------------------------
Expand Down Expand Up @@ -280,17 +291,13 @@ def join(*bs: "TypeBound") -> "TypeBound":
class Opaque(ConfiguredBaseModel):
"""An opaque Type that can be downcasted by the extensions that define it."""

t: Literal["Opaque"] = "Opaque"
extension: ExtensionId
id: str # Unique identifier of the opaque type.
args: list[TypeArg]
bound: TypeBound


class TaggedOpaque(ConfiguredBaseModel):
t: Literal["Opaque"] = "Opaque"
o: Opaque


class Alias(ConfiguredBaseModel):
"""An Alias Type"""

Expand All @@ -314,16 +321,13 @@ class Type(RootModel):
"""A HUGR type."""

root: Annotated[
Qubit
| Variable
| USize
| FunctionType
| Array
| TaggedSumType
| TaggedOpaque
| Alias,
Qubit | Variable | USize | FunctionType | Array | SumType | Opaque | Alias,
WrapValidator(_json_custom_error_validator),
] = Field(discriminator="t")
Field(discriminator="t"),
]

class Config:
json_schema_extra = {"required": ["t"]}


# -------------------------------------------
Expand Down Expand Up @@ -365,11 +369,9 @@ def model_rebuild(
config: ConfigDict = ConfigDict(),
**kwargs,
):
new_config = default_model_config.copy()
new_config.update(config)
for c in classes.values():
if issubclass(c, ConfiguredBaseModel):
c.set_model_config(new_config)
c.update_model_config(config)
c.model_rebuild(**kwargs)


Expand Down
4 changes: 2 additions & 2 deletions hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR
/// Generate a graph that loads and outputs `consts` in order, validating
/// against `reg`.
fn const_graph(consts: Vec<Value>, reg: &ExtensionRegistry) -> Hugr {
let const_types = consts.iter().map(Value::const_type).collect_vec();
let const_types = consts.iter().map(Value::get_type).collect_vec();
let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap();

let outputs = consts
Expand Down Expand Up @@ -338,7 +338,7 @@ mod test {
let list: Value = ListValue::new(BOOL_T, [Value::unit_sum(0, 1).unwrap()]).into();
let mut build = DFGBuilder::new(FunctionType::new(
type_row![],
vec![list.const_type().clone()],
vec![list.get_type().clone()],
))
.unwrap();

Expand Down
1 change: 0 additions & 1 deletion hugr/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@
//! # }
//! # doctest().unwrap();
//! ```
//!
use thiserror::Error;

use crate::extension::SignatureError;
Expand Down
2 changes: 1 addition & 1 deletion hugr/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ pub trait Dataflow: Container {
let load_n = self
.add_dataflow_op(
ops::LoadConstant {
datatype: op.const_type().clone(),
datatype: op.get_type().clone(),
},
// Constant wire from the constant value node
vec![Wire::new(const_node, OutgoingPort::from(0))],
Expand Down
38 changes: 18 additions & 20 deletions hugr/src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,11 @@ use crate::{
/// +------------+
/// */
/// use hugr::{
/// builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder},
/// Hugr,
/// extension::{ExtensionSet, prelude},
/// types::{FunctionType, Type, SumType},
/// ops,
/// type_row,
/// builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder},
/// extension::{prelude, ExtensionSet},
/// ops, type_row,
/// types::{FunctionType, SumType, Type},
/// Hugr,
/// };
///
/// const NAT: Type = prelude::USIZE_T;
Expand All @@ -75,7 +74,7 @@ use crate::{
/// let left_42 = ops::Value::sum(
/// 0,
/// [prelude::ConstUsize::new(42).into()],
/// SumType::new(sum_variants.clone())
/// SumType::new(sum_variants.clone()),
/// )?;
/// let sum = entry_b.add_load_value(left_42);
///
Expand All @@ -85,11 +84,10 @@ use crate::{
/// // This block will be the first successor of the entry node. It takes two
/// // `NAT` arguments: one from the `sum_variants` type, and another from the
/// // entry node's `other_outputs`.
/// let mut successor_builder =
/// cfg_builder.simple_block_builder(
/// FunctionType::new(type_row![NAT, NAT], type_row![NAT]),
/// 1 // only one successor to this block
/// )?;
/// let mut successor_builder = cfg_builder.simple_block_builder(
/// FunctionType::new(type_row![NAT, NAT], type_row![NAT]),
/// 1, // only one successor to this block
/// )?;
/// let successor_a = {
/// // This block has one successor. The choice is denoted by a unary sum.
/// let sum_unary = successor_builder.add_load_const(ops::Value::unary_unit_sum());
Expand All @@ -100,14 +98,14 @@ use crate::{
/// successor_builder.finish_with_outputs(sum_unary, [in_wire])?
/// };
///
/// // The only argument to this block is the entry node's `other_outputs`.
/// let mut successor_builder =
/// cfg_builder.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
/// let successor_b = {
/// let sum_unary = successor_builder.add_load_value(ops::Value::unary_unit_sum());
/// let [in_wire] = successor_builder.input_wires_arr();
/// successor_builder.finish_with_outputs(sum_unary, [in_wire])?
/// };
/// // The only argument to this block is the entry node's `other_outputs`.
/// let mut successor_builder = cfg_builder
/// .simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
/// let successor_b = {
/// let sum_unary = successor_builder.add_load_value(ops::Value::unary_unit_sum());
/// let [in_wire] = successor_builder.input_wires_arr();
/// successor_builder.finish_with_outputs(sum_unary, [in_wire])?
/// };
/// let exit = cfg_builder.exit_block();
/// cfg_builder.branch(&entry, 0, &successor_a)?; // branch 0 goes to successor_a
/// cfg_builder.branch(&entry, 1, &successor_b)?; // branch 1 goes to successor_b
Expand Down
Loading

0 comments on commit c45e6fc

Please sign in to comment.