From 58b6eabc0d03251f11690711405c6e6f327ccc9f Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 27 Nov 2023 12:29:42 +0000 Subject: [PATCH] move registry-based methods to wrapper struct --- src/extension/op_def.rs | 50 ++++++----- src/extension/simple_op.rs | 163 +++++++++++++++++++++++------------- src/std_extensions/logic.rs | 83 ++++++++++++------ 3 files changed, 190 insertions(+), 106 deletions(-) diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 848065d7a..246be5af2 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -211,6 +211,34 @@ impl SignatureFunc { SignatureFunc::CustomFunc(func) => func.static_params(), } } + pub fn compute_signature( + &self, + def: &OpDef, + args: &[TypeArg], + exts: &ExtensionRegistry, + ) -> Result { + let temp: PolyFuncType; + let (pf, args) = match &self { + SignatureFunc::TypeScheme(custom) => { + custom.validate.validate(args, def, exts)?; + (&custom.poly_func, args) + } + SignatureFunc::CustomFunc(func) => { + let static_params = func.static_params(); + let (static_args, other_args) = args.split_at(min(static_params.len(), args.len())); + + check_type_args(static_args, static_params)?; + temp = func.compute_signature(static_args, def, exts)?; + (&temp, other_args) + } + }; + + let res = pf.instantiate(args, exts)?; + // TODO bring this assert back once resource inference is done? + // https://github.com/CQCL-DEV/hugr/issues/425 + // assert!(res.contains(self.extension())); + Ok(res) + } } impl Debug for SignatureFunc { @@ -306,27 +334,7 @@ impl OpDef { args: &[TypeArg], exts: &ExtensionRegistry, ) -> Result { - let temp: PolyFuncType; - let (pf, args) = match &self.signature_func { - SignatureFunc::TypeScheme(custom) => { - custom.validate.validate(args, self, exts)?; - (&custom.poly_func, args) - } - SignatureFunc::CustomFunc(func) => { - let static_params = func.static_params(); - let (static_args, other_args) = args.split_at(min(static_params.len(), args.len())); - - check_type_args(static_args, static_params)?; - temp = func.compute_signature(static_args, self, exts)?; - (&temp, other_args) - } - }; - - let res = pf.instantiate(args, exts)?; - // TODO bring this assert back once resource inference is done? - // https://github.com/CQCL-DEV/hugr/issues/425 - // assert!(res.contains(self.extension())); - Ok(res) + self.signature_func.compute_signature(self, args, exts) } pub(crate) fn should_serialize_signature(&self) -> bool { diff --git a/src/extension/simple_op.rs b/src/extension/simple_op.rs index 1363e621b..72649def2 100644 --- a/src/extension/simple_op.rs +++ b/src/extension/simple_op.rs @@ -5,13 +5,16 @@ use strum::IntoEnumIterator; use crate::{ ops::{custom::ExtensionOp, LeafOp, OpType}, - types::TypeArg, + types::{FunctionType, TypeArg}, Extension, }; -use super::{op_def::SignatureFunc, ExtensionBuildError, ExtensionId, ExtensionRegistry, OpDef}; +use super::{ + op_def::SignatureFunc, ExtensionBuildError, ExtensionId, ExtensionRegistry, OpDef, + SignatureError, +}; +use delegate::delegate; use thiserror::Error; - /// Error when definition extension does not match that of the [OpEnum] #[derive(Debug, Error, PartialEq)] #[error("Expected extension ID {expected} but found {provided}.")] @@ -29,13 +32,22 @@ pub enum OpLoadError { LoadError(T), } +trait IntoStaticSt { + fn to_static_str(&self) -> &str; +} + +impl IntoStaticSt for T +where + for<'a> &'a T: Into<&'static str>, +{ + fn to_static_str(&self) -> &str { + self.into() + } +} /// A trait that operation sets defined by simple (C-style) enums can implement /// to simplify interactions with the extension. /// Relies on `strum_macros::{EnumIter, EnumString, IntoStaticStr}` -pub trait OpEnum: Into<&'static str> + FromStr + Copy + IntoEnumIterator { - /// The name of the extension these ops belong to. - const EXTENSION_ID: ExtensionId; - +pub trait OpEnum: FromStr + IntoEnumIterator + IntoStaticSt { // TODO can be removed after rust 1.75 /// Error thrown when loading from string fails. type LoadError: std::error::Error; @@ -43,33 +55,26 @@ pub trait OpEnum: Into<&'static str> + FromStr + Copy + IntoEnumIterator { type Description: ToString; /// Return the signature (polymorphic function type) of the operation. - fn signature(&self) -> SignatureFunc; + fn def_signature(&self) -> SignatureFunc; /// Description of the operation. fn description(&self) -> Self::Description; - /// Edit the opdef before finalising. + /// Any type args which define this operation. Default is no type arguments. + fn type_args(&self) -> Vec { + vec![] + } + + /// Edit the opdef before finalising. By default does nothing. fn post_opdef(&self, _def: &mut OpDef) {} /// Name of the operation - derived from strum serialization. fn name(&self) -> &str { - (*self).into() + self.to_static_str() } - /// Load an operation from the name of the operation. - fn from_extension_name(op_name: &str) -> Result; - /// Try to load one of the operations of this set from an [OpDef]. - fn try_from_op_def(op_def: &OpDef) -> Result> { - if op_def.extension() != &Self::EXTENSION_ID { - return Err(WrongExtension { - expected: Self::EXTENSION_ID.clone(), - provided: op_def.extension().clone(), - } - .into()); - } - Self::from_extension_name(op_def.name()).map_err(OpLoadError::LoadError) - } + fn from_op_def(op_def: &OpDef, args: &[TypeArg]) -> Result; /// Add an operation to an extension. fn add_to_extension<'e>( @@ -79,7 +84,7 @@ pub trait OpEnum: Into<&'static str> + FromStr + Copy + IntoEnumIterator { let def = ext.add_op( self.name().into(), self.description().to_string(), - self.signature(), + self.def_signature(), )?; self.post_opdef(def); @@ -87,7 +92,8 @@ pub trait OpEnum: Into<&'static str> + FromStr + Copy + IntoEnumIterator { Ok(def) } - /// Iterator over all operations in the set. + /// Iterator over all operations in the set. Non-trivial variants will have + /// default values used for the members. fn all_variants() -> ::Iterator { ::iter() } @@ -100,31 +106,73 @@ pub trait OpEnum: Into<&'static str> + FromStr + Copy + IntoEnumIterator { Ok(()) } + /// Try to instantiate a variant from an [OpType]. Default behaviour assumes + /// an [ExtensionOp] and loads from the name. + fn from_optype(op: &OpType) -> Option { + let ext: &ExtensionOp = op.as_leaf_op()?.as_extension_op()?; + Self::from_op_def(ext.def(), ext.args()).ok() + } + + fn to_registered<'r>( + self, + extension_id: ExtensionId, + registry: &'r ExtensionRegistry, + ) -> RegisteredEnum<'r, Self> { + RegisteredEnum { + extension_id, + registry, + op_enum: self, + } + } +} + +pub struct RegisteredEnum<'r, T> { + /// The name of the extension these ops belong to. + extension_id: ExtensionId, + registry: &'r ExtensionRegistry, + op_enum: T, +} + +impl<'a, T: OpEnum> RegisteredEnum<'a, T> { /// Generate an [OpType]. - fn to_optype( - &self, - extension: &Extension, - args: &[TypeArg], - exts: &ExtensionRegistry, - ) -> Option { - let leaf: LeafOp = ExtensionOp::new(extension.get_op(self.name())?.clone(), args, exts) - .ok()? - .into(); + pub fn to_optype(&self) -> Option { + let leaf: LeafOp = ExtensionOp::new( + self.registry + .get(&self.extension_id)? + .get_op(self.name())? + .clone(), + self.type_args(), + self.registry, + ) + .ok()? + .into(); Some(leaf.into()) } - /// Try to instantiate a variant from an [OpType]. Default behaviour assumes - /// an [ExtensionOp] and loads from the name. - fn from_optype(op: &OpType) -> Option { - let ext: &ExtensionOp = op.as_leaf_op()?.as_extension_op()?; - Self::try_from_op_def(ext.def()).ok() + pub fn function_type(&self) -> Result { + self.op_enum.def_signature().compute_signature( + self.registry + .get(&self.extension_id) + .expect("should return 'Extension not in registry' error here.") + .get_op(self.name()) + .expect("should return 'Op not in extension' error here."), + &self.type_args(), + self.registry, + ) + } + delegate! { + to self.op_enum { + pub fn name(&self) -> &str; + pub fn type_args(&self) -> Vec; + pub fn description(&self) -> T::Description; + } } } #[cfg(test)] mod test { - use crate::{extension::EMPTY_REG, type_row, types::FunctionType}; + use crate::{type_row, types::FunctionType}; use super::*; use strum_macros::{EnumIter, EnumString, IntoStaticStr}; @@ -136,13 +184,11 @@ mod test { #[error("Dummy")] struct DummyError; impl OpEnum for DummyEnum { - const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("dummy"); - type LoadError = DummyError; type Description = &'static str; - fn signature(&self) -> SignatureFunc { + fn def_signature(&self) -> SignatureFunc { FunctionType::new_endo(type_row![]).into() } @@ -150,7 +196,7 @@ mod test { "dummy" } - fn from_extension_name(_op_name: &str) -> Result { + fn from_op_def(op_def: &OpDef, args: &[TypeArg]) -> Result { Ok(Self::Dumb) } } @@ -159,31 +205,28 @@ mod test { fn test_dummy_enum() { let o = DummyEnum::Dumb; - let good_name = ExtensionId::new("dummy").unwrap(); - let mut e = Extension::new(good_name.clone()); + let ext_name = ExtensionId::new("dummy").unwrap(); + let mut e = Extension::new(ext_name.clone()); o.add_to_extension(&mut e).unwrap(); assert_eq!( - DummyEnum::try_from_op_def(e.get_op(o.name()).unwrap()).unwrap(), + DummyEnum::from_op_def(e.get_op(o.name()).unwrap(), &[]).unwrap(), o ); assert_eq!( - DummyEnum::from_optype(&o.to_optype(&e, &[], &EMPTY_REG).unwrap()).unwrap(), + DummyEnum::from_optype( + &o.clone() + .to_registered( + ext_name, + &ExtensionRegistry::try_new([e.to_owned()]).unwrap() + ) + .to_optype() + .unwrap() + ) + .unwrap(), o ); - let bad_name = ExtensionId::new("not_dummy").unwrap(); - let mut e = Extension::new(bad_name.clone()); - - o.add_to_extension(&mut e).unwrap(); - - assert_eq!( - DummyEnum::try_from_op_def(e.get_op(o.name()).unwrap()), - Err(OpLoadError::WrongExtension(WrongExtension { - expected: good_name, - provided: bad_name - })) - ); } } diff --git a/src/std_extensions/logic.rs b/src/std_extensions/logic.rs index 2738d1455..b9c923934 100644 --- a/src/std_extensions/logic.rs +++ b/src/std_extensions/logic.rs @@ -4,7 +4,7 @@ use strum_macros::{EnumIter, EnumString, IntoStaticStr}; use crate::{ extension::{ - prelude::BOOL_T, simple_op::OpEnum, ExtensionId, SignatureError, SignatureFromArgs, + prelude::BOOL_T, simple_op::OpEnum, ExtensionId, OpDef, SignatureError, SignatureFromArgs, SignatureFunc, }, ops, type_row, @@ -23,42 +23,62 @@ pub const FALSE_NAME: &str = "FALSE"; pub const TRUE_NAME: &str = "TRUE"; /// Logic extension operations. -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] +#[derive(Clone, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] #[allow(missing_docs)] pub enum LogicOp { - And, - Or, + And(u64), + Or(u64), Not, } /// Error in trying to load logic operation. #[derive(Debug, Error)] -#[error("Not a logic extension operation.")] -pub struct NotLogicOp; +pub enum LogicOpLoadError { + #[error("Not a logic extension operation.")] + NotLogicOp, + #[error("Type args invalid: {0}.")] + InvalidArgs(#[from] SignatureError), +} impl OpEnum for LogicOp { - const EXTENSION_ID: ExtensionId = EXTENSION_ID; - type LoadError = NotLogicOp; + type LoadError = LogicOpLoadError; type Description = &'static str; - fn from_extension_name(op_name: &str) -> Result { - Self::from_str(op_name).map_err(|_| NotLogicOp) + fn from_op_def(op_def: &OpDef, args: &[TypeArg]) -> Result { + let mut out = Self::from_str(op_def.name()).map_err(|_| LogicOpLoadError::NotLogicOp)?; + match &mut out { + LogicOp::And(i) | LogicOp::Or(i) => { + let [TypeArg::BoundedNat { n }] = *args else { + return Err(SignatureError::InvalidTypeArgs.into()); + }; + *i = n; + } + LogicOp::Not => (), + } + + Ok(out) } - fn signature(&self) -> SignatureFunc { + fn def_signature(&self) -> SignatureFunc { match self { - LogicOp::Or | LogicOp::And => logic_op_sig().into(), + LogicOp::Or(_) | LogicOp::And(_) => logic_op_sig().into(), LogicOp::Not => FunctionType::new_endo(type_row![BOOL_T]).into(), } } fn description(&self) -> &'static str { match self { - LogicOp::And => "logical 'and'", - LogicOp::Or => "logical 'or'", + LogicOp::And(_) => "logical 'and'", + LogicOp::Or(_) => "logical 'or'", LogicOp::Not => "logical 'not'", } } + fn type_args(&self) -> Vec { + match self { + LogicOp::And(n) | LogicOp::Or(n) => vec![TypeArg::BoundedNat { n: *n }], + LogicOp::Not => vec![], + } + } } /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("logic"); @@ -107,15 +127,18 @@ lazy_static! { #[cfg(test)] pub(crate) mod test { + use super::{extension, LogicOp, EXTENSION, EXTENSION_ID, FALSE_NAME, TRUE_NAME}; use crate::{ - extension::{prelude::BOOL_T, simple_op::OpEnum, EMPTY_REG}, + extension::{prelude::BOOL_T, simple_op::OpEnum, ExtensionRegistry}, ops::OpType, - types::type_param::TypeArg, + types::TypeArg, Extension, }; - - use super::{extension, LogicOp, EXTENSION, FALSE_NAME, TRUE_NAME}; - + use lazy_static::lazy_static; + lazy_static! { + pub(crate) static ref LOGIC_REG: ExtensionRegistry = + ExtensionRegistry::try_new([EXTENSION.to_owned()]).unwrap(); + } #[test] fn test_logic_extension() { let r: Extension = extension(); @@ -124,7 +147,12 @@ pub(crate) mod test { for op in LogicOp::all_variants() { assert_eq!( - LogicOp::try_from_op_def(r.get_op(op.name()).unwrap()).unwrap(), + LogicOp::from_op_def( + r.get_op(op.name()).unwrap(), + // `all_variants` will set default type arg values. + &[TypeArg::BoundedNat { n: 0 }] + ) + .unwrap(), op ); } @@ -144,20 +172,25 @@ pub(crate) mod test { /// Generate a logic extension and "and" operation over [`crate::prelude::BOOL_T`] pub(crate) fn and_op() -> OpType { - LogicOp::And - .to_optype(&EXTENSION, &[TypeArg::BoundedNat { n: 2 }], &EMPTY_REG) + LogicOp::And(2) + .to_registered(EXTENSION_ID.to_owned(), &LOGIC_REG) + .to_optype() .unwrap() } /// Generate a logic extension and "or" operation over [`crate::prelude::BOOL_T`] pub(crate) fn or_op() -> OpType { - LogicOp::Or - .to_optype(&EXTENSION, &[TypeArg::BoundedNat { n: 2 }], &EMPTY_REG) + LogicOp::Or(2) + .to_registered(EXTENSION_ID.to_owned(), &LOGIC_REG) + .to_optype() .unwrap() } /// Generate a logic extension and "not" operation over [`crate::prelude::BOOL_T`] pub(crate) fn not_op() -> OpType { - LogicOp::Not.to_optype(&EXTENSION, &[], &EMPTY_REG).unwrap() + LogicOp::Not + .to_registered(EXTENSION_ID.to_owned(), &LOGIC_REG) + .to_optype() + .unwrap() } }