Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add LoadFunction node #947

Merged
merged 4 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,18 @@ class LoadConstant(DataflowOp):
datatype: Type


class LoadFunction(DataflowOp):
"""Load a static function in to the local dataflow graph."""

op: Literal["LoadFunction"] = "LoadFunction"
func_sig: PolyFuncType
type_args: list[tys.TypeArg]
signature: FunctionType = Field(default_factory=FunctionType.empty)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
self.signature = FunctionType(input=list(in_types), output=list(out_types))


class DFG(DataflowOp):
"""A simply nested dataflow graph."""

Expand Down Expand Up @@ -502,6 +514,7 @@ class OpType(RootModel):
| Call
| CallIndirect
| LoadConstant
| LoadFunction
| CustomOp
| Noop
| MakeTuple
Expand Down
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/serialization/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ class FunctionType(BaseModel):
input: "TypeRow" # Value inputs of the function.
output: "TypeRow" # Value outputs of the function.
# The extension requirements which are added by the operation
extension_reqs: "ExtensionSet" = Field(default_factory=list)
extension_reqs: ExtensionSet = Field(default_factory=ExtensionSet)

@classmethod
def empty(cls) -> "FunctionType":
Expand Down
34 changes: 34 additions & 0 deletions hugr/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,40 @@ pub trait Dataflow: Container {
self.add_load_const(constant.into())
}

/// Load a static function and return the local dataflow wire for that function.
/// Adds a [`OpType::LoadFunction`] node.
///
/// The `DEF` const generic is used to indicate whether the function is defined
/// or just declared.
fn load_func<const DEFINED: bool>(
&mut self,
fid: &FuncID<DEFINED>,
type_args: &[TypeArg],
// Sadly required as we substituting in type_args may result in recomputing bounds of types:
exts: &ExtensionRegistry,
) -> Result<Wire, BuildError> {
let func_node = fid.node();
let func_op = self.hugr().get_nodetype(func_node).op();
let func_sig = match func_op {
OpType::FuncDefn(ops::FuncDefn { signature, .. })
| OpType::FuncDecl(ops::FuncDecl { signature, .. }) => signature.clone(),
_ => {
return Err(BuildError::UnexpectedType {
node: func_node,
op_desc: "FuncDecl/FuncDefn",
})
}
};

let load_n = self.add_dataflow_op(
ops::LoadFunction::try_new(func_sig, type_args, exts)?,
// Static wire from the function node
vec![Wire::new(func_node, func_op.static_output_port().unwrap())],
)?;

Ok(load_n.out_wire(0))
}

/// Return a builder for a [`crate::ops::TailLoop`] node.
/// The `inputs` must be an iterable over pairs of the type of the input and
/// the corresponding wire.
Expand Down
11 changes: 11 additions & 0 deletions hugr/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,17 @@ pub enum SignatureError {
cached: FunctionType,
expected: FunctionType,
},
/// The result of the type application stored in a [LoadFunction]
/// is not what we get by applying the type-args to the polymorphic function
///
/// [LoadFunction]: crate::ops::dataflow::LoadFunction
#[error(
"Incorrect result of type application in LoadFunction - cached {cached} but expected {expected}"
)]
LoadFunctionIncorrectlyAppliesType {
cached: FunctionType,
expected: FunctionType,
},
}

/// Concrete instantiations of types and operations defined in extensions.
Expand Down
4 changes: 4 additions & 0 deletions hugr/src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,10 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
c.validate(self.extension_registry)
.map_err(|cause| ValidationError::SignatureError { node, cause })?;
}
OpType::LoadFunction(c) => {
c.validate(self.extension_registry)
.map_err(|cause| ValidationError::SignatureError { node, cause })?;
}
_ => (),
}

Expand Down
21 changes: 21 additions & 0 deletions hugr/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,27 @@ fn test_polymorphic_call() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

#[test]
fn test_polymorphic_load() -> Result<(), Box<dyn std::error::Error>> {
let mut m = ModuleBuilder::new();
let id = m.declare(
"id",
PolyFuncType::new(
vec![TypeBound::Any.into()],
FunctionType::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]),
),
)?;
let sig = FunctionType::new(
vec![],
vec![Type::new_function(FunctionType::new_endo(vec![USIZE_T]))],
);
let mut f = m.define_function("main", sig.into())?;
let l = f.load_func(&id, &[USIZE_T.into()], &PRELUDE_REGISTRY)?;
f.finish_with_outputs([l])?;
let _ = m.finish_prelude_hugr()?;
Ok(())
}

#[cfg(feature = "extension_inference")]
mod extension_tests {
use super::*;
Expand Down
10 changes: 8 additions & 2 deletions hugr/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ use enum_dispatch::enum_dispatch;
pub use constant::{Const, Value};
pub use controlflow::{BasicBlock, Case, Conditional, DataflowBlock, ExitBlock, TailLoop, CFG};
pub use custom::CustomOp;
pub use dataflow::{Call, CallIndirect, DataflowParent, Input, LoadConstant, Output, DFG};
pub use dataflow::{
Call, CallIndirect, DataflowParent, Input, LoadConstant, LoadFunction, Output, DFG,
};
pub use leaf::{Lift, MakeTuple, Noop, Tag, UnpackTuple};
pub use module::{AliasDecl, AliasDefn, FuncDecl, FuncDefn, Module};
use smol_str::SmolStr;
Expand All @@ -47,6 +49,7 @@ pub enum OpType {
Call,
CallIndirect,
LoadConstant,
LoadFunction,
DFG,
CustomOp,
Noop,
Expand Down Expand Up @@ -105,6 +108,7 @@ impl_op_ref_try_into!(Output);
impl_op_ref_try_into!(Call);
impl_op_ref_try_into!(CallIndirect);
impl_op_ref_try_into!(LoadConstant);
impl_op_ref_try_into!(LoadFunction);
impl_op_ref_try_into!(DFG, dfg);
impl_op_ref_try_into!(CustomOp);
impl_op_ref_try_into!(Noop);
Expand Down Expand Up @@ -226,7 +230,8 @@ impl OpType {
Some(Port::new(dir, self.value_port_count(dir)))
}

/// If the op has a static input ([`Call`] and [`LoadConstant`]), the port of that input.
/// If the op has a static input ([`Call`], [`LoadConstant`], and [`LoadFunction`]), the port of
/// that input.
#[inline]
pub fn static_input_port(&self) -> Option<IncomingPort> {
self.static_port(Direction::Incoming)
Expand Down Expand Up @@ -419,6 +424,7 @@ impl OpParent for Output {}
impl OpParent for Call {}
impl OpParent for CallIndirect {}
impl OpParent for LoadConstant {}
impl OpParent for LoadFunction {}
impl OpParent for CustomOp {}
impl OpParent for Noop {}
impl OpParent for MakeTuple {}
Expand Down
80 changes: 80 additions & 0 deletions hugr/src/ops/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,86 @@ impl LoadConstant {
}
}

/// Load a static function in to the local dataflow graph.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct LoadFunction {
/// Signature of the function
func_sig: PolyFuncType,
type_args: Vec<TypeArg>,
signature: FunctionType, // Cache, so we can fail in try_new() not in signature()
}
impl_op_name!(LoadFunction);
impl DataflowOpTrait for LoadFunction {
const TAG: OpTag = OpTag::LoadFunc;

fn description(&self) -> &str {
"Load a static function in to the local dataflow graph"
}

fn signature(&self) -> FunctionType {
self.signature.clone()
}

fn static_input(&self) -> Option<EdgeKind> {
Some(EdgeKind::Function(self.func_sig.clone()))
}
}
impl LoadFunction {
/// Try to make a new LoadFunction op. Returns an error if the `type_args`` do not fit
/// the [TypeParam]s declared by the function.
///
/// [TypeParam]: crate::types::type_param::TypeParam
pub fn try_new(
func_sig: PolyFuncType,
type_args: impl Into<Vec<TypeArg>>,
exts: &ExtensionRegistry,
) -> Result<Self, SignatureError> {
let type_args = type_args.into();
let instantiation = func_sig.instantiate(&type_args, exts)?;
let signature = FunctionType::new(TypeRow::new(), vec![Type::new_function(instantiation)]);
Ok(Self {
func_sig,
type_args,
signature,
})
}

#[inline]
/// Return the type of the function loaded by this op.
pub fn function_type(&self) -> &PolyFuncType {
&self.func_sig
}

/// The IncomingPort which links to the loaded function.
///
/// This matches [`OpType::static_input_port`].
///
/// [`OpType::static_input_port`]: crate::ops::OpType::static_input_port
doug-q marked this conversation as resolved.
Show resolved Hide resolved
#[inline]
pub fn function_port(&self) -> IncomingPort {
0.into()
}

pub(crate) fn validate(
&self,
extension_registry: &ExtensionRegistry,
) -> Result<(), SignatureError> {
let other = Self::try_new(
self.func_sig.clone(),
self.type_args.clone(),
extension_registry,
)?;
if other.signature == self.signature {
Ok(())
} else {
Err(SignatureError::LoadFunctionIncorrectlyAppliesType {
cached: self.signature.clone(),
expected: other.signature.clone(),
})
}
}
}

/// Operations that is the parent of a dataflow graph.
pub trait DataflowParent {
/// Signature of the inner dataflow graph.
Expand Down
6 changes: 5 additions & 1 deletion hugr/src/ops/tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ pub enum OpTag {
FnCall,
/// A constant load operation.
LoadConst,
/// A function load operation.
LoadFunc,
/// A definition that could be at module level or inside a DSG.
ScopedDefn,
/// A tail-recursive loop.
Expand Down Expand Up @@ -129,6 +131,7 @@ impl OpTag {
OpTag::StaticOutput => &[OpTag::Any],
OpTag::FnCall => &[OpTag::StaticInput, OpTag::DataflowChild],
OpTag::LoadConst => &[OpTag::StaticInput, OpTag::DataflowChild],
OpTag::LoadFunc => &[OpTag::StaticInput, OpTag::DataflowChild],
OpTag::Leaf => &[OpTag::DataflowChild],
OpTag::DataflowParent => &[OpTag::Any],
}
Expand Down Expand Up @@ -156,10 +159,11 @@ impl OpTag {
OpTag::Cfg => "Nested control-flow operation",
OpTag::TailLoop => "Tail-recursive loop",
OpTag::Conditional => "Conditional operation",
OpTag::StaticInput => "Node with static input (LoadConst or FnCall)",
OpTag::StaticInput => "Node with static input (LoadConst, LoadFunc, or FnCall)",
OpTag::StaticOutput => "Node with static output (FuncDefn, FuncDecl, Const)",
OpTag::FnCall => "Function call",
OpTag::LoadConst => "Constant load operation",
OpTag::LoadFunc => "Function load operation",
OpTag::Leaf => "Leaf operation",
OpTag::ScopedDefn => "Definitions that can live at global or local scope",
OpTag::DataflowParent => "Operation whose children form a Dataflow Sibling Graph",
Expand Down
3 changes: 2 additions & 1 deletion hugr/src/ops/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ mod test {

use super::{
AliasDecl, AliasDefn, Call, CallIndirect, Const, CustomOp, FuncDecl, Input, Lift, LoadConstant,
MakeTuple, Noop, Output, Tag, UnpackTuple,
LoadFunction, MakeTuple, Noop, Output, Tag, UnpackTuple,
};
impl_validate_op!(FuncDecl);
impl_validate_op!(AliasDecl);
Expand All @@ -417,6 +417,7 @@ impl_validate_op!(Output);
impl_validate_op!(Const);
impl_validate_op!(Call);
impl_validate_op!(LoadConstant);
impl_validate_op!(LoadFunction);
impl_validate_op!(CallIndirect);
impl_validate_op!(CustomOp);
impl_validate_op!(Noop);
Expand Down
45 changes: 45 additions & 0 deletions specification/schema/hugr_schema_v1.json
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,47 @@
"title": "LoadConstant",
"type": "object"
},
"LoadFunction": {
"description": "Load a static function in to the local dataflow graph.",
"properties": {
"parent": {
"title": "Parent",
"type": "integer"
},
"input_extensions": {
"$ref": "#/$defs/ExtensionSet"
},
"op": {
"const": "LoadFunction",
"default": "LoadFunction",
"enum": [
"LoadFunction"
],
"title": "Op",
"type": "string"
},
"func_sig": {
"$ref": "#/$defs/PolyFuncType"
},
"type_args": {
"items": {
"$ref": "#/$defs/TypeArg"
},
"title": "Type Args",
"type": "array"
},
"signature": {
"$ref": "#/$defs/FunctionType"
}
},
"required": [
"parent",
"func_sig",
"type_args"
],
"title": "LoadFunction",
"type": "object"
},
"MakeTuple": {
"description": "An operation that packs all its inputs into a tuple.",
"properties": {
Expand Down Expand Up @@ -994,6 +1035,7 @@
"Input": "#/$defs/Input",
"Lift": "#/$defs/Lift",
"LoadConstant": "#/$defs/LoadConstant",
"LoadFunction": "#/$defs/LoadFunction",
"MakeTuple": "#/$defs/MakeTuple",
"Module": "#/$defs/Module",
"Noop": "#/$defs/Noop",
Expand Down Expand Up @@ -1050,6 +1092,9 @@
{
"$ref": "#/$defs/LoadConstant"
},
{
"$ref": "#/$defs/LoadFunction"
},
{
"$ref": "#/$defs/CustomOp"
},
Expand Down
Loading