Skip to content

Commit

Permalink
feat: Instantiate inferred extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor committed Aug 29, 2023
1 parent 95c0b56 commit 8218e9c
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 13 deletions.
23 changes: 15 additions & 8 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -857,21 +857,28 @@ mod test {
let [w] = mult.outputs_arr();

builder.set_outputs([w])?;
let hugr = builder.base;
// TODO: when we put new extensions onto the graph after inference, we
// can call `finish_hugr` and just look at the graph
let (solution, extra) = infer_extensions(&hugr)?;
assert!(extra.is_empty());
let mut hugr = builder.base;
let closure = hugr.infer_extensions()?;
assert!(closure.is_empty());
assert_eq!(
*solution.get(&(src.node(), Direction::Outgoing)).unwrap(),
hugr.get_nodetype(src.node())
.signature()
.unwrap()
.output_extensions(),
rs
);
assert_eq!(
*solution.get(&(mult.node(), Direction::Incoming)).unwrap(),
hugr.get_nodetype(mult.node())
.signature()
.unwrap()
.input_extensions,
rs
);
assert_eq!(
*solution.get(&(mult.node(), Direction::Outgoing)).unwrap(),
hugr.get_nodetype(mult.node())
.signature()
.unwrap()
.output_extensions(),
rs
);
Ok(())
Expand Down
81 changes: 76 additions & 5 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ impl Hugr {
rw.apply(self)
}

/// Infer extension requirements
/// Infer extension requirements and add new information to `op_types` field
pub fn infer_extensions(
&mut self,
) -> Result<HashMap<(Node, Direction), ExtensionSet>, InferExtensionError> {
Expand All @@ -202,9 +202,22 @@ impl Hugr {
Ok(extension_closure)
}

/// TODO: Write this
fn instantiate_extensions(&mut self, _solution: ExtensionSolution) {
//todo!()
/// Add extension requirement information to the hugr in place.
fn instantiate_extensions(&mut self, solution: ExtensionSolution) {
// We only care about inferred _input_ extensions, because `NodeType`
// uses those to infer the output extensions
for ((node, _), input_extensions) in solution
.iter()
.filter(|((_, dir), _)| *dir == Direction::Incoming)
{
let nodetype = self.op_types.try_get_mut(node.index).unwrap();
match nodetype.signature() {
None => nodetype.input_extensions = Some(input_extensions.clone()),
Some(existing_ext_reqs) => {
debug_assert_eq!(existing_ext_reqs.input_extensions, *input_extensions)
}
}
}
}
}

Expand Down Expand Up @@ -366,7 +379,14 @@ impl From<HugrError> for PyErr {

#[cfg(test)]
mod test {
use super::Hugr;
use super::{Hugr, HugrView, NodeType};
use crate::extension::ExtensionSet;
use crate::hugr::hugrmut::HugrInternalsMut;
use crate::ops;
use crate::type_row;
use crate::types::{FunctionType, Type};

use std::error::Error;

#[test]
fn impls_send_and_sync() {
Expand All @@ -385,4 +405,55 @@ mod test {
let hugr = simple_dfg_hugr();
assert_matches!(hugr.get_io(hugr.root()), Some(_));
}

#[test]
fn extension_instantiation() -> Result<(), Box<dyn Error>> {
const BIT: Type = crate::extension::prelude::USIZE_T;
let r = ExtensionSet::singleton(&"R".into());

let root = NodeType::pure(ops::DFG {
signature: FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&r),
});
let mut hugr = Hugr::new(root);
let input = hugr.add_node_with_parent(
hugr.root(),
NodeType::pure(ops::Input {
types: type_row![BIT],
}),
)?;
let output = hugr.add_node_with_parent(
hugr.root(),
NodeType::open_extensions(ops::Output {
types: type_row![BIT],
}),
)?;
let lift = hugr.add_node_with_parent(
hugr.root(),
NodeType::open_extensions(ops::LeafOp::Lift {
type_row: type_row![BIT],
new_extension: "R".into(),
}),
)?;
hugr.connect(input, 0, lift, 0)?;
hugr.connect(lift, 0, output, 0)?;
hugr.infer_extensions()?;

assert_eq!(
hugr.op_types
.get(lift.index)
.signature()
.unwrap()
.input_extensions,
ExtensionSet::new()
);
assert_eq!(
hugr.op_types
.get(output.index)
.signature()
.unwrap()
.input_extensions,
r
);
Ok(())
}
}

0 comments on commit 8218e9c

Please sign in to comment.