Skip to content

Commit

Permalink
feat: MakeRegisteredOp trait for easier registration
Browse files Browse the repository at this point in the history
especially when static registry references are available
  • Loading branch information
ss2165 committed Nov 30, 2023
1 parent 04fdad0 commit 841d9a2
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 62 deletions.
76 changes: 65 additions & 11 deletions src/extension/simple_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ pub trait MakeExtensionOp: OpName {
}
}

/// Blanket implementation for non-polymorphic operations - no type parameters.
/// Blanket implementation for non-polymorphic operations - [OpDef]s with no type parameters.
impl<T: MakeOpDef> MakeExtensionOp for T {
#[inline]
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
Expand Down Expand Up @@ -187,12 +187,50 @@ impl<T: MakeExtensionOp> RegisteredOp<'_, T> {
}
}

/// Trait for operations that can self report the extension ID they belong to
/// and the registry required to compute their types.
/// Allows conversion to [`ExtensionOp`]
pub trait MakeRegisteredOp: MakeExtensionOp {
/// The ID of the extension this op belongs to.
fn extension_id(&self) -> ExtensionId;
/// A reference to an [ExtensionRegistry] which is sufficient to generate
/// the signature of this op.
fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry;

/// Convert this operation in to an [ExtensionOp]. Returns None if the type
/// cannot be computed.
fn to_extension_op(self) -> Option<ExtensionOp>
where
Self: Sized,
{
let registered: RegisteredOp<_> = self.into();
registered.to_extension_op()
}
}

impl<T: MakeRegisteredOp> From<T> for RegisteredOp<'_, T> {
fn from(ext_op: T) -> Self {
let extension_id = ext_op.extension_id();
let registry = ext_op.registry();
ext_op.to_registered(extension_id, registry)
}
}

impl<T: MakeRegisteredOp + MakeExtensionOp> From<T> for OpType {
/// Convert
fn from(ext_op: T) -> Self {
ext_op.to_extension_op().unwrap().into()
}
}

#[cfg(test)]
mod test {
use crate::{type_row, types::FunctionType};
use crate::{const_extension_ids, type_row, types::FunctionType};

use super::*;
use lazy_static::lazy_static;
use strum_macros::{EnumIter, EnumString, IntoStaticStr};

#[derive(Clone, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
enum DummyEnum {
Dumb,
Expand All @@ -207,27 +245,43 @@ mod test {
Ok(Self::Dumb)
}
}
const_extension_ids! {
const EXT_ID: ExtensionId = "DummyExt";
}

lazy_static! {
static ref EXT: Extension = {
let mut e = Extension::new(EXT_ID.clone());
DummyEnum::Dumb.add_to_extension(&mut e).unwrap();
e
};
static ref DUMMY_REG: ExtensionRegistry =
ExtensionRegistry::try_new([EXT.to_owned()]).unwrap();
}
impl MakeRegisteredOp for DummyEnum {
fn extension_id(&self) -> ExtensionId {
EXT_ID.to_owned()
}

fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry {
&DUMMY_REG
}
}

#[test]
fn test_dummy_enum() {
let o = DummyEnum::Dumb;

let ext_name = ExtensionId::new("dummy").unwrap();
let mut e = Extension::new(ext_name.clone());

o.add_to_extension(&mut e).unwrap();
assert_eq!(
DummyEnum::from_def(e.get_op(&o.name()).unwrap()).unwrap(),
DummyEnum::from_def(EXT.get_op(&o.name()).unwrap()).unwrap(),
o
);

let registry = ExtensionRegistry::try_new([e.to_owned()]).unwrap();
let registered = o.clone().to_registered(ext_name, &registry);
assert_eq!(
DummyEnum::from_optype(&registered.to_extension_op().unwrap().into()).unwrap(),
DummyEnum::from_optype(&o.clone().to_extension_op().unwrap().into()).unwrap(),
o
);

let registered: RegisteredOp<_> = o.clone().into();
assert_eq!(registered.to_inner(), o);
}
}
8 changes: 4 additions & 4 deletions src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ use crate::hugr::{HugrError, HugrMut, NodeType};
use crate::macros::const_extension_ids;
use crate::ops::dataflow::IOTrait;
use crate::ops::{self, Const, LeafOp, OpType};
use crate::std_extensions::logic;
use crate::std_extensions::logic::test::{and_op, not_op, or_op};
use crate::std_extensions::logic::test::{and_op, or_op};
use crate::std_extensions::logic::{self, NotOp};
use crate::types::type_param::{TypeArg, TypeArgError, TypeParam};
use crate::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound, TypeRow};
use crate::values::Value;
Expand Down Expand Up @@ -602,8 +602,8 @@ fn dfg_with_cycles() -> Result<(), HugrError> {
));
let [input, output] = h.get_io(h.root()).unwrap();
let or = h.add_node_with_parent(h.root(), or_op())?;
let not1 = h.add_node_with_parent(h.root(), not_op())?;
let not2 = h.add_node_with_parent(h.root(), not_op())?;
let not1 = h.add_node_with_parent(h.root(), NotOp)?;
let not2 = h.add_node_with_parent(h.root(), NotOp)?;
h.connect(input, 0, or, 0)?;
h.connect(or, 0, not1, 0)?;
h.connect(not1, 0, or, 1)?;
Expand Down
13 changes: 5 additions & 8 deletions src/hugr/views/sibling_subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ mod tests {
hugr::views::{HierarchyView, SiblingGraph},
hugr::HugrMut,
ops::handle::{DfgID, FuncID, NodeHandle},
std_extensions::logic::test::{and_op, not_op},
std_extensions::logic::{test::and_op, NotOp},
type_row,
};

Expand Down Expand Up @@ -742,9 +742,9 @@ mod tests {
let func = mod_builder.declare("test", FunctionType::new_endo(type_row![BOOL_T]).into())?;
let func_id = {
let mut dfg = mod_builder.define_declaration(&func)?;
let outs1 = dfg.add_dataflow_op(not_op(), dfg.input_wires())?;
let outs2 = dfg.add_dataflow_op(not_op(), outs1.outputs())?;
let outs3 = dfg.add_dataflow_op(not_op(), outs2.outputs())?;
let outs1 = dfg.add_dataflow_op(NotOp, dfg.input_wires())?;
let outs2 = dfg.add_dataflow_op(NotOp, outs1.outputs())?;
let outs3 = dfg.add_dataflow_op(NotOp, outs2.outputs())?;
dfg.finish_with_outputs(outs3.outputs())?
};
let hugr = mod_builder
Expand Down Expand Up @@ -976,10 +976,7 @@ mod tests {
let mut builder =
DFGBuilder::new(FunctionType::new(one_bit.clone(), two_bit.clone())).unwrap();
let inw = builder.input_wires().exactly_one().unwrap();
let outw1 = builder
.add_dataflow_op(not_op(), [inw])
.unwrap()
.out_wire(0);
let outw1 = builder.add_dataflow_op(NotOp, [inw]).unwrap().out_wire(0);
let outw2 = builder
.add_dataflow_op(and_op(), [inw, outw1])
.unwrap()
Expand Down
7 changes: 3 additions & 4 deletions src/hugr/views/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
builder::{BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr},
extension::prelude::QB_T,
ops::handle::{DataflowOpID, NodeHandle},
std_extensions::logic::NotOp,
type_row,
types::FunctionType,
utils::test_quantum_extension::cx_gate,
Expand Down Expand Up @@ -102,7 +103,6 @@ fn all_ports(sample_hugr: (Hugr, BuildHandle<DataflowOpID>, BuildHandle<Dataflow
fn value_types() {
use crate::builder::Container;
use crate::extension::prelude::BOOL_T;
use crate::std_extensions::logic::test::not_op;
use crate::utils::test_quantum_extension::h_gate;
use itertools::Itertools;
let mut dfg = DFGBuilder::new(FunctionType::new(
Expand All @@ -113,7 +113,7 @@ fn value_types() {

let [q, b] = dfg.input_wires_arr();
let n1 = dfg.add_dataflow_op(h_gate(), [q]).unwrap();
let n2 = dfg.add_dataflow_op(not_op(), [b]).unwrap();
let n2 = dfg.add_dataflow_op(NotOp, [b]).unwrap();
dfg.add_other_wire(n1.node(), n2.node()).unwrap();
let h = dfg
.finish_prelude_hugr_with_outputs([n2.out_wire(0), n1.out_wire(0)])
Expand Down Expand Up @@ -158,7 +158,6 @@ fn test_dataflow_ports_only() {
use crate::builder::DataflowSubContainer;
use crate::extension::{prelude::BOOL_T, PRELUDE_REGISTRY};
use crate::hugr::views::PortIterator;
use crate::std_extensions::logic::test::not_op;
use itertools::Itertools;
let mut dfg = DFGBuilder::new(FunctionType::new(type_row![BOOL_T], type_row![BOOL_T])).unwrap();
let local_and = {
Expand All @@ -173,7 +172,7 @@ fn test_dataflow_ports_only() {
};
let [in_bool] = dfg.input_wires_arr();

let not = dfg.add_dataflow_op(not_op(), [in_bool]).unwrap();
let not = dfg.add_dataflow_op(NotOp, [in_bool]).unwrap();
let call = dfg
.call(
local_and.handle(),
Expand Down
68 changes: 33 additions & 35 deletions src/std_extensions/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use strum_macros::{EnumIter, EnumString, IntoStaticStr};
use crate::{
extension::{
prelude::BOOL_T,
simple_op::{try_from_name, MakeExtensionOp, MakeOpDef, OpLoadError},
ExtensionId, OpDef, SignatureError, SignatureFromArgs, SignatureFunc,
simple_op::{try_from_name, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError},
ExtensionId, ExtensionRegistry, OpDef, SignatureError, SignatureFromArgs, SignatureFunc,
},
ops::{self, custom::ExtensionOp, OpName},
type_row,
Expand Down Expand Up @@ -138,29 +138,41 @@ fn extension() -> Extension {
lazy_static! {
/// Reference to the logic Extension.
pub static ref EXTENSION: Extension = extension();
static ref LOGIC_REG: ExtensionRegistry =
ExtensionRegistry::try_new([EXTENSION.to_owned()]).unwrap();
}

impl MakeRegisteredOp for ConcreteLogicOp {
fn extension_id(&self) -> ExtensionId {
EXTENSION_ID.to_owned()
}

fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry {
&LOGIC_REG
}
}

impl MakeRegisteredOp for NotOp {
fn extension_id(&self) -> ExtensionId {
EXTENSION_ID.to_owned()
}

fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry {
&LOGIC_REG
}
}

#[cfg(test)]
pub(crate) mod test {
use super::{
extension, ConcreteLogicOp, NaryLogic, NotOp, EXTENSION, EXTENSION_ID, FALSE_NAME,
TRUE_NAME,
};
use super::{extension, ConcreteLogicOp, NaryLogic, FALSE_NAME, TRUE_NAME};
use crate::{
extension::{
prelude::BOOL_T,
simple_op::{MakeExtensionOp, MakeOpDef},
ExtensionRegistry,
},
ops::{custom::ExtensionOp, OpName},
extension::{prelude::BOOL_T, simple_op::MakeOpDef},
ops::OpName,
Extension,
};
use lazy_static::lazy_static;

use strum::IntoEnumIterator;
lazy_static! {
pub(crate) static ref LOGIC_REG: ExtensionRegistry =
ExtensionRegistry::try_new([EXTENSION.to_owned()]).unwrap();
}

#[test]
fn test_logic_extension() {
let r: Extension = extension();
Expand All @@ -187,27 +199,13 @@ pub(crate) mod test {
}
}

/// Generate a logic extension and "and" operation over [`crate::prelude::BOOL_T`]
pub(crate) fn and_op() -> ExtensionOp {
/// Generate a logic extension "and" operation over [`crate::prelude::BOOL_T`]
pub(crate) fn and_op() -> ConcreteLogicOp {
ConcreteLogicOp(NaryLogic::And, 2)
.to_registered(EXTENSION_ID.to_owned(), &LOGIC_REG)
.to_extension_op()
.unwrap()
}

/// Generate a logic extension and "or" operation over [`crate::prelude::BOOL_T`]
pub(crate) fn or_op() -> ExtensionOp {
/// Generate a logic extension "or" operation over [`crate::prelude::BOOL_T`]
pub(crate) fn or_op() -> ConcreteLogicOp {
ConcreteLogicOp(NaryLogic::Or, 2)
.to_registered(EXTENSION_ID.to_owned(), &LOGIC_REG)
.to_extension_op()
.unwrap()
}

/// Generate a logic extension and "not" operation over [`crate::prelude::BOOL_T`]
pub(crate) fn not_op() -> ExtensionOp {
NotOp
.to_registered(EXTENSION_ID.to_owned(), &LOGIC_REG)
.to_extension_op()
.unwrap()
}
}

0 comments on commit 841d9a2

Please sign in to comment.