From d98fb791a5ca5651fba659e0e55da59db451a975 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 9 Apr 2024 14:43:30 +0100 Subject: [PATCH] refactor: Combine ExtensionSolutions (no separate closure) (#884) * `infer::infer_extensions` returns only a combined solution (for previously-open locations), after variables instantiated * `Hugr::infer_extensions` writes (all parts of) the solution into place *and* returns it * `validate_with_extension_closure` left in-place, with test demonstrating usage w/ sub-DFGs * This should open the way (in future PRs) to changing implementation of `infer::infer_extensions` --- hugr/src/extension/infer.rs | 24 ++++----- hugr/src/extension/infer/test.rs | 91 +++++++++++++++++++++++++++++--- hugr/src/hugr.rs | 26 ++++----- hugr/src/hugr/validate/test.rs | 7 ++- 4 files changed, 108 insertions(+), 40 deletions(-) diff --git a/hugr/src/extension/infer.rs b/hugr/src/extension/infer.rs index e08a23e78..0e77098e8 100644 --- a/hugr/src/extension/infer.rs +++ b/hugr/src/extension/infer.rs @@ -31,24 +31,22 @@ use thiserror::Error; /// been inferred for their inputs. pub type ExtensionSolution = HashMap; -/// Infer extensions for a hugr. This is the main API exposed by this module +/// Infer extensions for a hugr. This is the main API exposed by this module. /// -/// Return a tuple of the solutions found for locations on the graph, and a -/// closure: a solution which would be valid if all of the variables in the graph -/// were instantiated to an empty extension set. This is used (by validation) to -/// concretise the extension requirements of the whole hugr. -pub fn infer_extensions( - hugr: &impl HugrView, -) -> Result<(ExtensionSolution, ExtensionSolution), InferExtensionError> { +/// Return all the solutions found for locations on the graph, these can be +/// passed to [`validate_with_extension_closure`] +/// +/// [`validate_with_extension_closure`]: crate::Hugr::validate_with_extension_closure +pub fn infer_extensions(hugr: &impl HugrView) -> Result { let mut ctx = UnificationContext::new(hugr); - let solution = ctx.main_loop()?; + ctx.main_loop()?; ctx.instantiate_variables(); - let closed_solution = ctx.main_loop()?; - let closure: ExtensionSolution = closed_solution + let all_results = ctx.main_loop()?; + let new_results = all_results .into_iter() - .filter(|(node, _)| !solution.contains_key(node)) + .filter(|(n, _sol)| hugr.get_nodetype(*n).input_extensions().is_none()) .collect(); - Ok((solution, closure)) + Ok(new_results) } /// Metavariables don't need much diff --git a/hugr/src/extension/infer/test.rs b/hugr/src/extension/infer/test.rs index 91aa7788b..256480614 100644 --- a/hugr/src/extension/infer/test.rs +++ b/hugr/src/extension/infer/test.rs @@ -15,7 +15,8 @@ use crate::ops::{LeafOp, OpType}; #[cfg(feature = "extension_inference")] use crate::{ builder::test::closed_dfg_root_hugr, - hugr::validate::ValidationError, + extension::prelude::PRELUDE_ID, + hugr::{hugrmut::sealed::HugrMutInternals, validate::ValidationError}, ops::{dataflow::DataflowParent, handle::NodeHandle}, }; @@ -100,13 +101,13 @@ fn from_graph() -> Result<(), Box> { hugr.connect(mult_c, 0, output, 0); - let (_, closure) = infer_extensions(&hugr)?; + let solution = infer_extensions(&hugr)?; let empty = ExtensionSet::new(); let ab = ExtensionSet::from_iter([A, B]); - assert_eq!(*closure.get(&(hugr.root())).unwrap(), empty); - assert_eq!(*closure.get(&(mult_c)).unwrap(), ab); - assert_eq!(*closure.get(&(add_ab)).unwrap(), empty); - assert_eq!(*closure.get(&add_b).unwrap(), ExtensionSet::singleton(&A)); + assert_eq!(*solution.get(&(hugr.root())).unwrap(), empty); + assert_eq!(*solution.get(&(mult_c)).unwrap(), ab); + assert_eq!(*solution.get(&(add_ab)).unwrap(), empty); + assert_eq!(*solution.get(&add_b).unwrap(), ExtensionSet::singleton(&A)); Ok(()) } @@ -249,8 +250,7 @@ fn dangling_src() -> Result<(), Box> { hugr.connect(src, 0, mult, 1); hugr.connect(mult, 0, output, 0); - let closure = hugr.infer_extensions()?; - assert!(closure.is_empty()); + hugr.infer_extensions()?; assert_eq!(hugr.get_nodetype(src.node()).io_extensions().unwrap().1, rs); assert_eq!( hugr.get_nodetype(mult.node()).io_extensions().unwrap(), @@ -795,6 +795,81 @@ fn test_cfg_loops() -> Result<(), Box> { Ok(()) } +#[test] +#[cfg(feature = "extension_inference")] +fn test_validate_with_closure() -> Result<(), Box> { + fn dfg_hugr_with_exts(e: Option) -> (Hugr, Node, Node) { + let mut h = closed_dfg_root_hugr(FunctionType::new_endo(type_row![QB_T])); + h.replace_op(h.root(), NodeType::new(h.get_optype(h.root()).clone(), e)) + .unwrap(); + let [input, output] = h.get_io(h.root()).unwrap(); + (h, input, output) + } + fn identity_hugr_with_exts(e: Option) -> Hugr { + let (mut h, input, output) = dfg_hugr_with_exts(e); + h.connect(input, 0, output, 0); + h + } + + const EXT_ID: ExtensionId = ExtensionId::new_unchecked("foo"); + + let inner_open = identity_hugr_with_exts(None); + + let inner_prelude = identity_hugr_with_exts(Some(ExtensionSet::singleton(&PRELUDE_ID))); + + let inner_other = identity_hugr_with_exts(Some(ExtensionSet::singleton(&EXT_ID))); + + // All three can be inferred and validated, without writing solutions in: + for inner in [&inner_open, &inner_prelude, &inner_other] { + assert_matches!( + inner.validate(&PRELUDE_REGISTRY), + Err(ValidationError::ExtensionError(_)) + ); + + let soln = infer_extensions(inner)?; + inner.validate_with_extension_closure(soln, &PRELUDE_REGISTRY)?; + } + + // Helper builds a Hugr with extensions {PRELUDE_ID}, around argument + let build_outer_prelude = |inner: Hugr| -> Hugr { + let (mut h, input, output) = dfg_hugr_with_exts(Some(ExtensionSet::singleton(&PRELUDE_ID))); + let inner_node = h.insert_hugr(h.root(), inner).new_root; + h.connect(input, 0, inner_node, 0); + h.connect(inner_node, 0, output, 0); + h + }; + + // Building a Hugr around the inner DFG works if the inner DFG is open, + // or has the correct (prelude) extensions: + for inner in [&inner_open, &inner_prelude] { + let mut h = build_outer_prelude(inner.clone()); + h.update_validate(&PRELUDE_REGISTRY)?; + } + + // ...but fails if the inner DFG already has the 'wrong' extensions: + assert_matches!( + build_outer_prelude(inner_other.clone()).update_validate(&PRELUDE_REGISTRY), + Err(ValidationError::CantInfer(_)) + ); + + // If we do inference on the inner Hugr first, this (still) works if the + // inner DFG already had the correct input-extensions: + let mut inner_prelude_inferred = inner_prelude; + inner_prelude_inferred.update_validate(&PRELUDE_REGISTRY)?; + build_outer_prelude(inner_prelude_inferred).update_validate(&PRELUDE_REGISTRY)?; + + // But fails for previously-open inner DFG as inference + // infers an incorrect (empty) solution: + let mut inner_inferred = inner_open; + inner_inferred.update_validate(&PRELUDE_REGISTRY)?; + assert_matches!( + build_outer_prelude(inner_inferred).update_validate(&PRELUDE_REGISTRY), + Err(ValidationError::CantInfer(_)) + ); + + Ok(()) +} + #[test] /// A control flow graph consisting of an entry node and a single block /// which adds a resource and links to both itself and the exit node. diff --git a/hugr/src/hugr.rs b/hugr/src/hugr.rs index c92231804..9a60cf90d 100644 --- a/hugr/src/hugr.rs +++ b/hugr/src/hugr.rs @@ -8,8 +8,6 @@ pub mod serialize; pub mod validate; pub mod views; -#[cfg(not(feature = "extension_inference"))] -use std::collections::HashMap; use std::collections::VecDeque; use std::iter; @@ -198,29 +196,27 @@ impl Hugr { extension_registry: &ExtensionRegistry, ) -> Result<(), ValidationError> { resolve_extension_ops(self, extension_registry)?; - let closure = self.infer_extensions()?; - self.validate_with_extension_closure(closure, extension_registry)?; + self.infer_extensions()?; + self.validate(extension_registry)?; Ok(()) } /// Infer extension requirements and add new information to `op_types` field + /// (if the "extension_inference" feature is on; otherwise, do nothing) /// /// See [`infer_extensions`] for details on the "closure" value - #[cfg(feature = "extension_inference")] - pub fn infer_extensions(&mut self) -> Result { - let (solution, extension_closure) = infer_extensions(self)?; - self.instantiate_extensions(solution); - Ok(extension_closure) - } - /// Do nothing - this functionality is gated by the feature "extension_inference" - #[cfg(not(feature = "extension_inference"))] - pub fn infer_extensions(&mut self) -> Result { - Ok(HashMap::new()) + pub fn infer_extensions(&mut self) -> Result<(), InferExtensionError> { + #[cfg(feature = "extension_inference")] + { + let solution = infer_extensions(self)?; + self.instantiate_extensions(&solution); + } + Ok(()) } #[allow(dead_code)] /// Add extension requirement information to the hugr in place. - fn instantiate_extensions(&mut self, solution: ExtensionSolution) { + fn instantiate_extensions(&mut self, solution: &ExtensionSolution) { // We only care about inferred _input_ extensions, because `NodeType` // uses those to infer the output extensions for (node, input_extensions) in solution.iter() { diff --git a/hugr/src/hugr/validate/test.rs b/hugr/src/hugr/validate/test.rs index 6ecf00699..0e1ecaee0 100644 --- a/hugr/src/hugr/validate/test.rs +++ b/hugr/src/hugr/validate/test.rs @@ -149,14 +149,14 @@ fn children_restrictions() { b.update_validate(&EMPTY_REG), Err(ValidationError::NonContainerWithChildren { node, .. }) => assert_eq!(node, copy) ); - let closure = b.infer_extensions().unwrap(); + b.infer_extensions().unwrap(); b.set_parent(new_def, root); // After moving the previous definition to a valid place, // add an input node to the module subgraph let new_input = b.add_node_with_parent(root, ops::Input::new(type_row![])); assert_matches!( - b.validate_with_extension_closure(closure, &EMPTY_REG), + b.validate(&EMPTY_REG), Err(ValidationError::InvalidParentOp { parent, child, .. }) => {assert_eq!(parent, root); assert_eq!(child, new_input)} ); } @@ -608,8 +608,7 @@ mod extension_tests { .unwrap(); // Write Extension annotations into the Hugr while it's still well-formed // enough for us to compute them - let closure = b.infer_extensions().unwrap(); - b.instantiate_extensions(closure); + b.infer_extensions().unwrap(); b.validate(&EMPTY_REG).unwrap(); b.replace_op( copy,