diff --git a/src/extension/simple_op.rs b/src/extension/simple_op.rs index bd9f280dd..ff4b62778 100644 --- a/src/extension/simple_op.rs +++ b/src/extension/simple_op.rs @@ -119,7 +119,7 @@ pub trait MakeExtensionOp: OpName { } } -/// Blanket implementation for non-polymorphic operations - no type parameters. +/// Blanket implementation for non-polymorphic operations - [OpDef]s with no type parameters. impl MakeExtensionOp for T { #[inline] fn from_extension_op(ext_op: &ExtensionOp) -> Result @@ -187,12 +187,50 @@ impl RegisteredOp<'_, T> { } } +/// Trait for operations that can self report the extension ID they belong to +/// and the registry required to compute their types. +/// Allows conversion to [`ExtensionOp`] +pub trait MakeRegisteredOp: MakeExtensionOp { + /// The ID of the extension this op belongs to. + fn extension_id(&self) -> ExtensionId; + /// A reference to an [ExtensionRegistry] which is sufficient to generate + /// the signature of this op. + fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry; + + /// Convert this operation in to an [ExtensionOp]. Returns None if the type + /// cannot be computed. + fn to_extension_op(self) -> Option + where + Self: Sized, + { + let registered: RegisteredOp<_> = self.into(); + registered.to_extension_op() + } +} + +impl From for RegisteredOp<'_, T> { + fn from(ext_op: T) -> Self { + let extension_id = ext_op.extension_id(); + let registry = ext_op.registry(); + ext_op.to_registered(extension_id, registry) + } +} + +impl From for OpType { + /// Convert + fn from(ext_op: T) -> Self { + ext_op.to_extension_op().unwrap().into() + } +} + #[cfg(test)] mod test { - use crate::{type_row, types::FunctionType}; + use crate::{const_extension_ids, type_row, types::FunctionType}; use super::*; + use lazy_static::lazy_static; use strum_macros::{EnumIter, EnumString, IntoStaticStr}; + #[derive(Clone, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] enum DummyEnum { Dumb, @@ -207,27 +245,43 @@ mod test { Ok(Self::Dumb) } } + const_extension_ids! { + const EXT_ID: ExtensionId = "DummyExt"; + } + + lazy_static! { + static ref EXT: Extension = { + let mut e = Extension::new(EXT_ID.clone()); + DummyEnum::Dumb.add_to_extension(&mut e).unwrap(); + e + }; + static ref DUMMY_REG: ExtensionRegistry = + ExtensionRegistry::try_new([EXT.to_owned()]).unwrap(); + } + impl MakeRegisteredOp for DummyEnum { + fn extension_id(&self) -> ExtensionId { + EXT_ID.to_owned() + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { + &DUMMY_REG + } + } #[test] fn test_dummy_enum() { let o = DummyEnum::Dumb; - let ext_name = ExtensionId::new("dummy").unwrap(); - let mut e = Extension::new(ext_name.clone()); - - o.add_to_extension(&mut e).unwrap(); assert_eq!( - DummyEnum::from_def(e.get_op(&o.name()).unwrap()).unwrap(), + DummyEnum::from_def(EXT.get_op(&o.name()).unwrap()).unwrap(), o ); - let registry = ExtensionRegistry::try_new([e.to_owned()]).unwrap(); - let registered = o.clone().to_registered(ext_name, ®istry); assert_eq!( - DummyEnum::from_optype(®istered.to_extension_op().unwrap().into()).unwrap(), + DummyEnum::from_optype(&o.clone().to_extension_op().unwrap().into()).unwrap(), o ); - + let registered: RegisteredOp<_> = o.clone().into(); assert_eq!(registered.to_inner(), o); } } diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index 424f5f049..8b1545049 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -15,8 +15,8 @@ use crate::hugr::{HugrError, HugrMut, NodeType}; use crate::macros::const_extension_ids; use crate::ops::dataflow::IOTrait; use crate::ops::{self, Const, LeafOp, OpType}; -use crate::std_extensions::logic; -use crate::std_extensions::logic::test::{and_op, not_op, or_op}; +use crate::std_extensions::logic::test::{and_op, or_op}; +use crate::std_extensions::logic::{self, NotOp}; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound, TypeRow}; use crate::values::Value; @@ -602,8 +602,8 @@ fn dfg_with_cycles() -> Result<(), HugrError> { )); let [input, output] = h.get_io(h.root()).unwrap(); let or = h.add_node_with_parent(h.root(), or_op())?; - let not1 = h.add_node_with_parent(h.root(), not_op())?; - let not2 = h.add_node_with_parent(h.root(), not_op())?; + let not1 = h.add_node_with_parent(h.root(), NotOp)?; + let not2 = h.add_node_with_parent(h.root(), NotOp)?; h.connect(input, 0, or, 0)?; h.connect(or, 0, not1, 0)?; h.connect(not1, 0, or, 1)?; diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index 19b2b9503..461083378 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -689,7 +689,7 @@ mod tests { hugr::views::{HierarchyView, SiblingGraph}, hugr::HugrMut, ops::handle::{DfgID, FuncID, NodeHandle}, - std_extensions::logic::test::{and_op, not_op}, + std_extensions::logic::{test::and_op, NotOp}, type_row, }; @@ -742,9 +742,9 @@ mod tests { let func = mod_builder.declare("test", FunctionType::new_endo(type_row![BOOL_T]).into())?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; - let outs1 = dfg.add_dataflow_op(not_op(), dfg.input_wires())?; - let outs2 = dfg.add_dataflow_op(not_op(), outs1.outputs())?; - let outs3 = dfg.add_dataflow_op(not_op(), outs2.outputs())?; + let outs1 = dfg.add_dataflow_op(NotOp, dfg.input_wires())?; + let outs2 = dfg.add_dataflow_op(NotOp, outs1.outputs())?; + let outs3 = dfg.add_dataflow_op(NotOp, outs2.outputs())?; dfg.finish_with_outputs(outs3.outputs())? }; let hugr = mod_builder @@ -976,10 +976,7 @@ mod tests { let mut builder = DFGBuilder::new(FunctionType::new(one_bit.clone(), two_bit.clone())).unwrap(); let inw = builder.input_wires().exactly_one().unwrap(); - let outw1 = builder - .add_dataflow_op(not_op(), [inw]) - .unwrap() - .out_wire(0); + let outw1 = builder.add_dataflow_op(NotOp, [inw]).unwrap().out_wire(0); let outw2 = builder .add_dataflow_op(and_op(), [inw, outw1]) .unwrap() diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 6d81ff52a..a2a3274b9 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -102,9 +102,10 @@ fn all_ports(sample_hugr: (Hugr, BuildHandle, BuildHandle ConcreteLogicOp { + ConcreteLogicOp(self, n) + } +} impl OpName for ConcreteLogicOp { fn name(&self) -> smol_str::SmolStr { self.0.name() @@ -59,14 +67,10 @@ impl OpName for ConcreteLogicOp { impl MakeExtensionOp for ConcreteLogicOp { fn from_extension_op(ext_op: &ExtensionOp) -> Result { let def: NaryLogic = NaryLogic::from_def(ext_op.def())?; - Ok(match def { - NaryLogic::And | NaryLogic::Or => { - let [TypeArg::BoundedNat { n }] = *ext_op.args() else { - return Err(SignatureError::InvalidTypeArgs.into()); - }; - Self(def, n) - } - }) + let [TypeArg::BoundedNat { n }] = *ext_op.args() else { + return Err(SignatureError::InvalidTypeArgs.into()); + }; + Ok(Self(def, n)) } fn type_args(&self) -> Vec { @@ -142,29 +146,45 @@ fn extension() -> Extension { lazy_static! { /// Reference to the logic Extension. pub static ref EXTENSION: Extension = extension(); + /// Registry required to validate logic extension. + pub static ref LOGIC_REG: ExtensionRegistry = + ExtensionRegistry::try_new([EXTENSION.to_owned()]).unwrap(); +} + +impl MakeRegisteredOp for ConcreteLogicOp { + fn extension_id(&self) -> ExtensionId { + EXTENSION_ID.to_owned() + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { + &LOGIC_REG + } +} + +impl MakeRegisteredOp for NotOp { + fn extension_id(&self) -> ExtensionId { + EXTENSION_ID.to_owned() + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { + &LOGIC_REG + } } #[cfg(test)] pub(crate) mod test { - use super::{ - extension, ConcreteLogicOp, NaryLogic, NotOp, EXTENSION, EXTENSION_ID, FALSE_NAME, - TRUE_NAME, - }; + use super::{extension, ConcreteLogicOp, NaryLogic, NotOp, FALSE_NAME, TRUE_NAME}; use crate::{ extension::{ prelude::BOOL_T, - simple_op::{MakeExtensionOp, MakeOpDef}, - ExtensionRegistry, + simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp}, }, - ops::{custom::ExtensionOp, OpName}, + ops::OpName, Extension, }; - use lazy_static::lazy_static; + use strum::IntoEnumIterator; - lazy_static! { - pub(crate) static ref LOGIC_REG: ExtensionRegistry = - ExtensionRegistry::try_new([EXTENSION.to_owned()]).unwrap(); - } + #[test] fn test_logic_extension() { let r: Extension = extension(); @@ -179,6 +199,17 @@ pub(crate) mod test { } } + #[test] + fn test_conversions() { + for def in [NaryLogic::And, NaryLogic::Or] { + let o = def.with_n_inputs(3); + let ext_op = o.clone().to_extension_op().unwrap(); + assert_eq!(ConcreteLogicOp::from_extension_op(&ext_op).unwrap(), o); + } + + NotOp::from_extension_op(&NotOp.to_extension_op().unwrap()).unwrap(); + } + #[test] fn test_values() { let r: Extension = extension(); @@ -191,27 +222,13 @@ pub(crate) mod test { } } - /// Generate a logic extension and "and" operation over [`crate::prelude::BOOL_T`] - pub(crate) fn and_op() -> ExtensionOp { - ConcreteLogicOp(NaryLogic::And, 2) - .to_registered(EXTENSION_ID.to_owned(), &LOGIC_REG) - .to_extension_op() - .unwrap() - } - - /// Generate a logic extension and "or" operation over [`crate::prelude::BOOL_T`] - pub(crate) fn or_op() -> ExtensionOp { - ConcreteLogicOp(NaryLogic::Or, 2) - .to_registered(EXTENSION_ID.to_owned(), &LOGIC_REG) - .to_extension_op() - .unwrap() + /// Generate a logic extension "and" operation over [`crate::prelude::BOOL_T`] + pub(crate) fn and_op() -> ConcreteLogicOp { + NaryLogic::And.with_n_inputs(2) } - /// Generate a logic extension and "not" operation over [`crate::prelude::BOOL_T`] - pub(crate) fn not_op() -> ExtensionOp { - NotOp - .to_registered(EXTENSION_ID.to_owned(), &LOGIC_REG) - .to_extension_op() - .unwrap() + /// Generate a logic extension "or" operation over [`crate::prelude::BOOL_T`] + pub(crate) fn or_op() -> ConcreteLogicOp { + NaryLogic::Or.with_n_inputs(2) } }