Skip to content

Commit

Permalink
move registry-based methods to wrapper struct
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Nov 27, 2023
1 parent 9bc6561 commit 58b6eab
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 106 deletions.
50 changes: 29 additions & 21 deletions src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,34 @@ impl SignatureFunc {
SignatureFunc::CustomFunc(func) => func.static_params(),
}
}
pub fn compute_signature(
&self,
def: &OpDef,
args: &[TypeArg],
exts: &ExtensionRegistry,
) -> Result<FunctionType, SignatureError> {
let temp: PolyFuncType;
let (pf, args) = match &self {
SignatureFunc::TypeScheme(custom) => {
custom.validate.validate(args, def, exts)?;
(&custom.poly_func, args)
}
SignatureFunc::CustomFunc(func) => {
let static_params = func.static_params();
let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));

check_type_args(static_args, static_params)?;
temp = func.compute_signature(static_args, def, exts)?;
(&temp, other_args)
}
};

let res = pf.instantiate(args, exts)?;
// TODO bring this assert back once resource inference is done?
// https://github.com/CQCL-DEV/hugr/issues/425
// assert!(res.contains(self.extension()));
Ok(res)
}
}

impl Debug for SignatureFunc {
Expand Down Expand Up @@ -306,27 +334,7 @@ impl OpDef {
args: &[TypeArg],
exts: &ExtensionRegistry,
) -> Result<FunctionType, SignatureError> {
let temp: PolyFuncType;
let (pf, args) = match &self.signature_func {
SignatureFunc::TypeScheme(custom) => {
custom.validate.validate(args, self, exts)?;
(&custom.poly_func, args)
}
SignatureFunc::CustomFunc(func) => {
let static_params = func.static_params();
let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));

check_type_args(static_args, static_params)?;
temp = func.compute_signature(static_args, self, exts)?;
(&temp, other_args)
}
};

let res = pf.instantiate(args, exts)?;
// TODO bring this assert back once resource inference is done?
// https://github.com/CQCL-DEV/hugr/issues/425
// assert!(res.contains(self.extension()));
Ok(res)
self.signature_func.compute_signature(self, args, exts)
}

pub(crate) fn should_serialize_signature(&self) -> bool {
Expand Down
163 changes: 103 additions & 60 deletions src/extension/simple_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ use strum::IntoEnumIterator;

use crate::{
ops::{custom::ExtensionOp, LeafOp, OpType},
types::TypeArg,
types::{FunctionType, TypeArg},
Extension,
};

use super::{op_def::SignatureFunc, ExtensionBuildError, ExtensionId, ExtensionRegistry, OpDef};
use super::{
op_def::SignatureFunc, ExtensionBuildError, ExtensionId, ExtensionRegistry, OpDef,
SignatureError,
};
use delegate::delegate;
use thiserror::Error;

/// Error when definition extension does not match that of the [OpEnum]
#[derive(Debug, Error, PartialEq)]
#[error("Expected extension ID {expected} but found {provided}.")]
Expand All @@ -29,47 +32,49 @@ pub enum OpLoadError<T> {
LoadError(T),
}

trait IntoStaticSt {
fn to_static_str(&self) -> &str;
}

impl<T> IntoStaticSt for T
where
for<'a> &'a T: Into<&'static str>,
{
fn to_static_str(&self) -> &str {
self.into()
}
}
/// A trait that operation sets defined by simple (C-style) enums can implement
/// to simplify interactions with the extension.
/// Relies on `strum_macros::{EnumIter, EnumString, IntoStaticStr}`
pub trait OpEnum: Into<&'static str> + FromStr + Copy + IntoEnumIterator {
/// The name of the extension these ops belong to.
const EXTENSION_ID: ExtensionId;

pub trait OpEnum: FromStr + IntoEnumIterator + IntoStaticSt {
// TODO can be removed after rust 1.75
/// Error thrown when loading from string fails.
type LoadError: std::error::Error;
/// Description type.
type Description: ToString;

/// Return the signature (polymorphic function type) of the operation.
fn signature(&self) -> SignatureFunc;
fn def_signature(&self) -> SignatureFunc;

/// Description of the operation.
fn description(&self) -> Self::Description;

/// Edit the opdef before finalising.
/// Any type args which define this operation. Default is no type arguments.
fn type_args(&self) -> Vec<TypeArg> {
vec![]
}

/// Edit the opdef before finalising. By default does nothing.
fn post_opdef(&self, _def: &mut OpDef) {}

/// Name of the operation - derived from strum serialization.
fn name(&self) -> &str {
(*self).into()
self.to_static_str()
}

/// Load an operation from the name of the operation.
fn from_extension_name(op_name: &str) -> Result<Self, Self::LoadError>;

/// Try to load one of the operations of this set from an [OpDef].
fn try_from_op_def(op_def: &OpDef) -> Result<Self, OpLoadError<Self::LoadError>> {
if op_def.extension() != &Self::EXTENSION_ID {
return Err(WrongExtension {
expected: Self::EXTENSION_ID.clone(),
provided: op_def.extension().clone(),
}
.into());
}
Self::from_extension_name(op_def.name()).map_err(OpLoadError::LoadError)
}
fn from_op_def(op_def: &OpDef, args: &[TypeArg]) -> Result<Self, Self::LoadError>;

/// Add an operation to an extension.
fn add_to_extension<'e>(
Expand All @@ -79,15 +84,16 @@ pub trait OpEnum: Into<&'static str> + FromStr + Copy + IntoEnumIterator {
let def = ext.add_op(
self.name().into(),
self.description().to_string(),
self.signature(),
self.def_signature(),
)?;

self.post_opdef(def);

Ok(def)
}

/// Iterator over all operations in the set.
/// Iterator over all operations in the set. Non-trivial variants will have
/// default values used for the members.
fn all_variants() -> <Self as IntoEnumIterator>::Iterator {
<Self as IntoEnumIterator>::iter()
}
Expand All @@ -100,31 +106,73 @@ pub trait OpEnum: Into<&'static str> + FromStr + Copy + IntoEnumIterator {
Ok(())
}

/// Try to instantiate a variant from an [OpType]. Default behaviour assumes
/// an [ExtensionOp] and loads from the name.
fn from_optype(op: &OpType) -> Option<Self> {
let ext: &ExtensionOp = op.as_leaf_op()?.as_extension_op()?;
Self::from_op_def(ext.def(), ext.args()).ok()
}

fn to_registered<'r>(
self,
extension_id: ExtensionId,
registry: &'r ExtensionRegistry,
) -> RegisteredEnum<'r, Self> {
RegisteredEnum {
extension_id,
registry,
op_enum: self,
}
}
}

pub struct RegisteredEnum<'r, T> {
/// The name of the extension these ops belong to.
extension_id: ExtensionId,
registry: &'r ExtensionRegistry,
op_enum: T,
}

impl<'a, T: OpEnum> RegisteredEnum<'a, T> {
/// Generate an [OpType].
fn to_optype(
&self,
extension: &Extension,
args: &[TypeArg],
exts: &ExtensionRegistry,
) -> Option<OpType> {
let leaf: LeafOp = ExtensionOp::new(extension.get_op(self.name())?.clone(), args, exts)
.ok()?
.into();
pub fn to_optype(&self) -> Option<OpType> {
let leaf: LeafOp = ExtensionOp::new(
self.registry
.get(&self.extension_id)?
.get_op(self.name())?
.clone(),
self.type_args(),
self.registry,
)
.ok()?
.into();

Some(leaf.into())
}

/// Try to instantiate a variant from an [OpType]. Default behaviour assumes
/// an [ExtensionOp] and loads from the name.
fn from_optype(op: &OpType) -> Option<Self> {
let ext: &ExtensionOp = op.as_leaf_op()?.as_extension_op()?;
Self::try_from_op_def(ext.def()).ok()
pub fn function_type(&self) -> Result<FunctionType, SignatureError> {
self.op_enum.def_signature().compute_signature(
self.registry
.get(&self.extension_id)
.expect("should return 'Extension not in registry' error here.")
.get_op(self.name())
.expect("should return 'Op not in extension' error here."),
&self.type_args(),
self.registry,
)
}
delegate! {
to self.op_enum {
pub fn name(&self) -> &str;
pub fn type_args(&self) -> Vec<TypeArg>;
pub fn description(&self) -> T::Description;
}
}
}

#[cfg(test)]
mod test {
use crate::{extension::EMPTY_REG, type_row, types::FunctionType};
use crate::{type_row, types::FunctionType};

use super::*;
use strum_macros::{EnumIter, EnumString, IntoStaticStr};
Expand All @@ -136,21 +184,19 @@ mod test {
#[error("Dummy")]
struct DummyError;
impl OpEnum for DummyEnum {
const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("dummy");

type LoadError = DummyError;

type Description = &'static str;

fn signature(&self) -> SignatureFunc {
fn def_signature(&self) -> SignatureFunc {
FunctionType::new_endo(type_row![]).into()
}

fn description(&self) -> Self::Description {
"dummy"
}

fn from_extension_name(_op_name: &str) -> Result<Self, Self::LoadError> {
fn from_op_def(op_def: &OpDef, args: &[TypeArg]) -> Result<Self, Self::LoadError> {
Ok(Self::Dumb)
}
}
Expand All @@ -159,31 +205,28 @@ mod test {
fn test_dummy_enum() {
let o = DummyEnum::Dumb;

let good_name = ExtensionId::new("dummy").unwrap();
let mut e = Extension::new(good_name.clone());
let ext_name = ExtensionId::new("dummy").unwrap();
let mut e = Extension::new(ext_name.clone());

o.add_to_extension(&mut e).unwrap();

assert_eq!(
DummyEnum::try_from_op_def(e.get_op(o.name()).unwrap()).unwrap(),
DummyEnum::from_op_def(e.get_op(o.name()).unwrap(), &[]).unwrap(),
o
);

assert_eq!(
DummyEnum::from_optype(&o.to_optype(&e, &[], &EMPTY_REG).unwrap()).unwrap(),
DummyEnum::from_optype(
&o.clone()
.to_registered(
ext_name,
&ExtensionRegistry::try_new([e.to_owned()]).unwrap()
)
.to_optype()
.unwrap()
)
.unwrap(),
o
);
let bad_name = ExtensionId::new("not_dummy").unwrap();
let mut e = Extension::new(bad_name.clone());

o.add_to_extension(&mut e).unwrap();

assert_eq!(
DummyEnum::try_from_op_def(e.get_op(o.name()).unwrap()),
Err(OpLoadError::WrongExtension(WrongExtension {
expected: good_name,
provided: bad_name
}))
);
}
}
Loading

0 comments on commit 58b6eab

Please sign in to comment.