diff --git a/src/builder/cfg.rs b/src/builder/cfg.rs index 030193c6a..d1663014b 100644 --- a/src/builder/cfg.rs +++ b/src/builder/cfg.rs @@ -6,7 +6,7 @@ use super::{ }; use crate::ops::handle::NodeHandle; -use crate::ops::{self, dataflow::DataflowParent, BasicBlock, OpType}; +use crate::ops::{self, dataflow::DataflowParent, Exit, OpType, DFB}; use crate::{ extension::{ExtensionRegistry, ExtensionSet}, types::FunctionType, @@ -86,7 +86,7 @@ impl + AsRef> CFGBuilder { output: TypeRow, ) -> Result { let n_out_wires = output.len(); - let exit_block_type = OpType::BasicBlock(BasicBlock::Exit { + let exit_block_type = OpType::Exit(Exit { cfg_outputs: output, }); let exit_node = base @@ -102,7 +102,7 @@ impl + AsRef> CFGBuilder { }) } - /// Return a builder for a non-entry [`BasicBlock::DFB`] child graph with `inputs` + /// Return a builder for a non-entry [`DFB`] child graph with `inputs` /// and `outputs` and the variants of the branching TupleSum value /// specified by `tuple_sum_rows`. /// @@ -134,7 +134,7 @@ impl + AsRef> CFGBuilder { entry: bool, ) -> Result, BuildError> { let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect(); - let op = OpType::BasicBlock(BasicBlock::DFB { + let op = OpType::DFB(DFB { inputs: inputs.clone(), other_outputs: other_outputs.clone(), tuple_sum_rows: tuple_sum_rows.clone(), @@ -153,7 +153,7 @@ impl + AsRef> CFGBuilder { BlockBuilder::create(self.hugr_mut(), block_n) } - /// Return a builder for a non-entry [`BasicBlock::DFB`] child graph with `inputs` + /// Return a builder for a non-entry [`DFB`] child graph with `inputs` /// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types. /// /// # Errors @@ -172,7 +172,7 @@ impl + AsRef> CFGBuilder { ) } - /// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs` + /// Return a builder for the entry [`DFB`] child graph with `inputs` /// and `outputs` and the variants of the branching TupleSum value /// specified by `tuple_sum_rows`. /// @@ -192,7 +192,7 @@ impl + AsRef> CFGBuilder { self.any_block_builder(inputs, tuple_sum_rows, other_outputs, extension_delta, true) } - /// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs` + /// Return a builder for the entry [`DFB`] child graph with `inputs` /// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types. /// /// # Errors @@ -229,7 +229,7 @@ impl + AsRef> CFGBuilder { } } -/// Builder for a [`BasicBlock::DFB`] child graph. +/// Builder for a [`DFB`] child graph. pub type BlockBuilder = DFGWrapper; impl + AsRef> BlockBuilder { @@ -243,7 +243,7 @@ impl + AsRef> BlockBuilder { Dataflow::set_outputs(self, [branch_wire].into_iter().chain(outputs)) } fn create(base: B, block_n: Node) -> Result { - let block_op = base.get_optype(block_n).as_basic_block().unwrap(); + let block_op = base.get_optype(block_n).as_dfb().unwrap(); let signature = block_op.inner_signature(); let inp_ex = base .as_ref() @@ -269,7 +269,7 @@ impl + AsRef> BlockBuilder { } impl BlockBuilder { - /// Initialize a [`BasicBlock::DFB`] rooted HUGR builder + /// Initialize a [`DFB`] rooted HUGR builder pub fn new( inputs: impl Into, input_extensions: impl Into>, @@ -280,7 +280,7 @@ impl BlockBuilder { let inputs = inputs.into(); let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect(); let other_outputs = other_outputs.into(); - let op = BasicBlock::DFB { + let op = DFB { inputs: inputs.clone(), other_outputs: other_outputs.clone(), tuple_sum_rows: tuple_sum_rows.clone(), diff --git a/src/extension/infer/test.rs b/src/extension/infer/test.rs index 0ec275ee7..d0da85156 100644 --- a/src/extension/infer/test.rs +++ b/src/extension/infer/test.rs @@ -462,7 +462,7 @@ fn make_block( let tuple_sum_type = Type::new_tuple_sum(tuple_sum_rows.clone()); let dfb_sig = FunctionType::new(inputs.clone(), vec![tuple_sum_type]) .with_extension_delta(&extension_delta.clone()); - let dfb = ops::BasicBlock::DFB { + let dfb = ops::DFB { inputs, other_outputs: type_row![], tuple_sum_rows, @@ -497,7 +497,7 @@ fn create_entry_exit( exit_types: impl Into, ) -> Result<([Node; 3], Node), Box> { let entry_tuple_sum = Type::new_tuple_sum(entry_variants.clone()); - let dfb = ops::BasicBlock::DFB { + let dfb = ops::DFB { inputs: inputs.clone(), other_outputs: type_row![], tuple_sum_rows: entry_variants, @@ -506,7 +506,7 @@ fn create_entry_exit( let exit = hugr.add_node_with_parent( root, - ops::BasicBlock::Exit { + ops::Exit { cfg_outputs: exit_types.into(), }, )?; diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index c13a4183f..ff19a01a7 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -12,9 +12,10 @@ use crate::hugr::rewrite::Rewrite; use crate::hugr::views::sibling::SiblingMut; use crate::hugr::{HugrMut, HugrView}; use crate::ops; +use crate::ops::controlflow::BasicBlock; use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle}; -use crate::ops::{BasicBlock, OpType}; +use crate::ops::{OpType, DFB}; use crate::PortIndex; use crate::{type_row, Node}; @@ -114,12 +115,13 @@ impl Rewrite for OutlineCfg { self.compute_entry_exit_outside_extensions(h)?; // 1. Compute signature // These panic()s only happen if the Hugr would not have passed validate() - let OpType::BasicBlock(BasicBlock::DFB { inputs, .. }) = h.get_optype(entry) else { + let OpType::DFB(DFB { inputs, .. }) = h.get_optype(entry) else { panic!("Entry node is not a basic block") }; let inputs = inputs.clone(); let outputs = match h.get_optype(outside) { - OpType::BasicBlock(b) => b.dataflow_input().clone(), + OpType::DFB(dfb) => dfb.dataflow_input().clone(), + OpType::Exit(exit) => exit.dataflow_input().clone(), _ => panic!("External successor not a basic block"), }; let outer_cfg = h.get_parent(entry).unwrap(); @@ -265,7 +267,6 @@ mod test { use crate::hugr::views::sibling::SiblingMut; use crate::hugr::HugrMut; use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle}; - use crate::ops::{BasicBlock, OpType}; use crate::types::FunctionType; use crate::{type_row, HugrView, Node}; use cool_asserts::assert_matches; @@ -348,12 +349,9 @@ mod test { h.output_neighbours(tail).take(2).collect::>(), HashSet::from([exit, new_block]) ); - assert_matches!( - h.get_optype(new_block), - OpType::BasicBlock(BasicBlock::DFB { .. }) - ); + assert!(h.get_optype(new_block).is_dfb()); assert_eq!(h.base_hugr().get_parent(new_cfg), Some(new_block)); - assert_matches!(h.base_hugr().get_optype(new_cfg), OpType::CFG(_)); + assert!(h.base_hugr().get_optype(new_cfg).is_cfg()); } #[test] @@ -409,12 +407,9 @@ mod test { .unwrap(); h.update_validate(&PRELUDE_REGISTRY).unwrap(); assert_eq!(new_block, h.children(h.root()).next().unwrap()); - assert_matches!( - h.get_optype(new_block), - OpType::BasicBlock(BasicBlock::DFB { .. }) - ); + assert!(h.get_optype(new_block).is_dfb()); assert_eq!(h.get_parent(new_cfg), Some(new_block)); - assert_matches!(h.get_optype(new_cfg), OpType::CFG(_)); + assert!(h.get_optype(new_cfg).is_cfg()); for n in other_blocks { assert_eq!(depth(&h, n), 1); } diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index 017407c97..7ae04d6b5 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -449,7 +449,7 @@ mod test { use crate::ops::custom::{ExternalOp, OpaqueOp}; use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle}; - use crate::ops::{self, BasicBlock, Case, LeafOp, OpTag, OpType, DFG}; + use crate::ops::{self, Case, LeafOp, OpTag, OpType, DFB, DFG}; use crate::std_extensions::collections; use crate::types::{FunctionType, Type, TypeArg, TypeRow}; use crate::{type_row, Direction, Hugr, HugrView, OutgoingPort}; @@ -504,14 +504,8 @@ mod test { let popp = h.get_parent(pop).unwrap(); let pushp = h.get_parent(push).unwrap(); assert_ne!(popp, pushp); // Two different BBs - assert!(matches!( - h.get_optype(popp), - OpType::BasicBlock(BasicBlock::DFB { .. }) - )); - assert!(matches!( - h.get_optype(pushp), - OpType::BasicBlock(BasicBlock::DFB { .. }) - )); + assert!(h.get_optype(popp).is_dfb()); + assert!(h.get_optype(pushp).is_dfb()); assert_eq!(h.get_parent(popp).unwrap(), h.get_parent(pushp).unwrap()); } @@ -523,7 +517,7 @@ mod test { })); let r_bb = replacement.add_node_with_parent( replacement.root(), - BasicBlock::DFB { + DFB { inputs: vec![listy.clone()].into(), tuple_sum_rows: vec![type_row![]], other_outputs: vec![listy.clone()].into(), @@ -596,10 +590,7 @@ mod test { let grandp = h.get_parent(popp).unwrap(); assert_eq!(grandp, h.get_parent(pushp).unwrap()); - assert!(matches!( - h.get_optype(grandp), - OpType::BasicBlock(BasicBlock::DFB { .. }) - )); + assert!(h.get_optype(grandp).is_dfb()); } Ok(()) diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index dc8e9add2..5ef926d4a 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -289,7 +289,7 @@ fn cfg_children_restrictions() { let block = b .add_node_with_parent( cfg, - ops::BasicBlock::DFB { + ops::DFB { inputs: type_row![BOOL_T], tuple_sum_rows: vec![type_row![]], other_outputs: type_row![BOOL_T], @@ -301,7 +301,7 @@ fn cfg_children_restrictions() { let exit = b .add_node_with_parent( cfg, - ops::BasicBlock::Exit { + ops::Exit { cfg_outputs: type_row![BOOL_T], }, ) @@ -315,7 +315,7 @@ fn cfg_children_restrictions() { let exit2 = b .add_node_after( exit, - ops::BasicBlock::Exit { + ops::Exit { cfg_outputs: type_row![BOOL_T], }, ) @@ -330,7 +330,7 @@ fn cfg_children_restrictions() { // Change the types in the BasicBlock node to work on qubits instead of bits b.replace_op( block, - NodeType::new_pure(ops::BasicBlock::DFB { + NodeType::new_pure(ops::DFB { inputs: type_row![Q], tuple_sum_rows: vec![type_row![]], other_outputs: type_row![Q], diff --git a/src/hugr/views/root_checked.rs b/src/hugr/views/root_checked.rs index 6b3f7aba3..63c909d51 100644 --- a/src/hugr/views/root_checked.rs +++ b/src/hugr/views/root_checked.rs @@ -74,7 +74,7 @@ mod test { use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::{HugrError, HugrMut, NodeType}; use crate::ops::handle::{BasicBlockID, CfgID, DataflowParentID, DfgID}; - use crate::ops::{BasicBlock, LeafOp, OpTag}; + use crate::ops::{LeafOp, OpTag, DFB}; use crate::{ops, type_row, types::FunctionType, Hugr, HugrView}; #[test] @@ -94,7 +94,7 @@ mod test { let mut dfg_v = RootChecked::<&mut Hugr, DfgID>::try_new(&mut h).unwrap(); // That is a HugrMutInternal, so we can try: let root = dfg_v.root(); - let bb = NodeType::new_pure(BasicBlock::DFB { + let bb = NodeType::new_pure(DFB { inputs: type_row![], other_outputs: type_row![], tuple_sum_rows: vec![type_row![]], diff --git a/src/ops.rs b/src/ops.rs index c0c8e48ae..313638020 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -21,7 +21,7 @@ use smol_str::SmolStr; use enum_dispatch::enum_dispatch; pub use constant::Const; -pub use controlflow::{BasicBlock, Case, Conditional, TailLoop, CFG}; +pub use controlflow::{Case, Conditional, Exit, TailLoop, CFG, DFB}; pub use dataflow::{Call, CallIndirect, DataflowParent, Input, LoadConstant, Output, DFG}; pub use leaf::LeafOp; pub use module::{AliasDecl, AliasDefn, FuncDecl, FuncDefn, Module}; @@ -48,7 +48,8 @@ pub enum OpType { LoadConstant, DFG, LeafOp, - BasicBlock, + DFB, + Exit, TailLoop, CFG, Conditional, @@ -93,7 +94,8 @@ impl_op_ref_try_into!(CallIndirect); impl_op_ref_try_into!(LoadConstant); impl_op_ref_try_into!(DFG, dfg); impl_op_ref_try_into!(LeafOp); -impl_op_ref_try_into!(BasicBlock); +impl_op_ref_try_into!(DFB, dfb); +impl_op_ref_try_into!(Exit); impl_op_ref_try_into!(TailLoop); impl_op_ref_try_into!(CFG, cfg); impl_op_ref_try_into!(Conditional); @@ -344,6 +346,7 @@ impl OpParent for TailLoop {} impl OpParent for CFG {} impl OpParent for Conditional {} impl OpParent for FuncDecl {} +impl OpParent for Exit {} #[enum_dispatch] /// Methods for Ops to validate themselves and children diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index 351796b94..f5ab94c57 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -115,70 +115,60 @@ impl DataflowOpTrait for CFG { } #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -#[serde(tag = "block")] -/// Basic block ops - nodes valid in control flow graphs. +/// A CFG basic block node. The signature is that of the internal Dataflow graph. #[allow(missing_docs)] -pub enum BasicBlock { - /// A CFG basic block node. The signature is that of the internal Dataflow graph. - DFB { - inputs: TypeRow, - other_outputs: TypeRow, - tuple_sum_rows: Vec, - extension_delta: ExtensionSet, - }, - /// The single exit node of the CFG, has no children, - /// stores the types of the CFG node output. - Exit { cfg_outputs: TypeRow }, +pub struct DFB { + pub inputs: TypeRow, + pub other_outputs: TypeRow, + pub tuple_sum_rows: Vec, + pub extension_delta: ExtensionSet, +} + +#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +/// The single exit node of the CFG, has no children, +/// stores the types of the CFG node output. +pub struct Exit { + /// Output type row of the CFG. + pub cfg_outputs: TypeRow, } -impl OpName for BasicBlock { - /// The name of the operation. +impl OpName for DFB { fn name(&self) -> SmolStr { - match self { - BasicBlock::DFB { .. } => "DFB".into(), - BasicBlock::Exit { .. } => "Exit".into(), - } + "DFB".into() } } -impl StaticTag for BasicBlock { +impl OpName for Exit { + fn name(&self) -> SmolStr { + "Exit".into() + } +} + +impl StaticTag for DFB { const TAG: OpTag = OpTag::BasicBlock; } -impl DataflowParent for BasicBlock { +impl StaticTag for Exit { + const TAG: OpTag = OpTag::BasicBlockExit; +} + +impl DataflowParent for DFB { fn inner_signature(&self) -> FunctionType { - match self { - BasicBlock::DFB { - inputs, - other_outputs, - tuple_sum_rows, - .. - } => { - // The node outputs a TupleSum before the data outputs of the block node - let tuple_sum_type = Type::new_tuple_sum(tuple_sum_rows.clone()); - let mut node_outputs = vec![tuple_sum_type]; - node_outputs.extend_from_slice(other_outputs); - FunctionType::new(inputs.clone(), TypeRow::from(node_outputs)) - } - BasicBlock::Exit { cfg_outputs } => FunctionType::new(type_row![], cfg_outputs.clone()), - } + // The node outputs a TupleSum before the data outputs of the block node + let tuple_sum_type = Type::new_tuple_sum(self.tuple_sum_rows.clone()); + let mut node_outputs = vec![tuple_sum_type]; + node_outputs.extend_from_slice(&self.other_outputs); + FunctionType::new(self.inputs.clone(), TypeRow::from(node_outputs)) } } -impl OpTrait for BasicBlock { - /// The description of the operation. +impl OpTrait for DFB { fn description(&self) -> &str { - match self { - BasicBlock::DFB { .. } => "A CFG basic block node", - BasicBlock::Exit { .. } => "A CFG exit block node", - } + "A CFG basic block node" } /// Tag identifying the operation. fn tag(&self) -> OpTag { - match self { - BasicBlock::DFB { .. } => OpTag::BasicBlock, - BasicBlock::Exit { .. } => OpTag::BasicBlockExit, - } + Self::TAG } fn other_input(&self) -> Option { @@ -190,43 +180,73 @@ impl OpTrait for BasicBlock { } fn dataflow_signature(&self) -> Option { - Some(match self { - BasicBlock::DFB { - extension_delta, .. - } => FunctionType::new(type_row![], type_row![]).with_extension_delta(extension_delta), - BasicBlock::Exit { .. } => FunctionType::new(type_row![], type_row![]), - }) + Some( + FunctionType::new(type_row![], type_row![]).with_extension_delta(&self.extension_delta), + ) } fn non_df_port_count(&self, dir: Direction) -> usize { - match self { - Self::DFB { tuple_sum_rows, .. } if dir == Direction::Outgoing => tuple_sum_rows.len(), - Self::Exit { .. } if dir == Direction::Outgoing => 0, - _ => 1, + match dir { + Direction::Incoming => 1, + Direction::Outgoing => self.tuple_sum_rows.len(), } } } -impl BasicBlock { - /// The input signature of the contained dataflow graph. - pub fn dataflow_input(&self) -> &TypeRow { - match self { - BasicBlock::DFB { inputs, .. } => inputs, - BasicBlock::Exit { cfg_outputs } => cfg_outputs, +impl OpTrait for Exit { + fn description(&self) -> &str { + "A CFG exit block node" + } + /// Tag identifying the operation. + fn tag(&self) -> OpTag { + Self::TAG + } + + fn other_input(&self) -> Option { + Some(EdgeKind::ControlFlow) + } + + fn other_output(&self) -> Option { + Some(EdgeKind::ControlFlow) + } + + fn dataflow_signature(&self) -> Option { + Some(FunctionType::new(type_row![], type_row![])) + } + + fn non_df_port_count(&self, dir: Direction) -> usize { + match dir { + Direction::Incoming => 1, + Direction::Outgoing => 0, } } +} + +/// Functionality shared by DFB and Exit CFG block types. +pub trait BasicBlock { + /// The input dataflow signature of the CFG block. + fn dataflow_input(&self) -> &TypeRow; +} +impl BasicBlock for DFB { + fn dataflow_input(&self) -> &TypeRow { + &self.inputs + } +} +impl DFB { /// The correct inputs of any successors. Returns None if successor is not a /// valid index. pub fn successor_input(&self, successor: usize) -> Option { - match self { - BasicBlock::DFB { - tuple_sum_rows, - other_outputs: outputs, - .. - } => Some(tuple_sum_first(tuple_sum_rows.get(successor)?, outputs)), - BasicBlock::Exit { .. } => panic!("Exit should have no successors"), - } + Some(tuple_sum_first( + self.tuple_sum_rows.get(successor)?, + &self.other_outputs, + )) + } +} + +impl BasicBlock for Exit { + fn dataflow_input(&self) -> &TypeRow { + &self.cfg_outputs } } diff --git a/src/ops/handle.rs b/src/ops/handle.rs index f62524f4e..118f3410d 100644 --- a/src/ops/handle.rs +++ b/src/ops/handle.rs @@ -103,7 +103,7 @@ impl AliasID { pub struct ConstID(Node); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] -/// Handle to a [BasicBlock](crate::ops::BasicBlock) node. +/// Handle to a [DFB](crate::ops::DFB) or [Exit](crate::ops::Exit) node. pub struct BasicBlockID(Node); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] diff --git a/src/ops/validate.rs b/src/ops/validate.rs index 357da2bc8..6c1e653a2 100644 --- a/src/ops/validate.rs +++ b/src/ops/validate.rs @@ -6,14 +6,15 @@ //! It also defines a `validate_op_children` method for more complex tests that //! require traversing the children. +use crate::types::TypeRow; use itertools::Itertools; use portgraph::{NodeIndex, PortOffset}; +use std::any::type_name; use thiserror::Error; -use crate::types::{FunctionType, Type, TypeRow}; - +use super::controlflow::BasicBlock; use super::dataflow::DataflowParent; -use super::{impl_validate_op, BasicBlock, OpTag, OpTrait, OpType, ValidateOp}; +use super::{impl_validate_op, Exit, OpTag, OpTrait, OpType, ValidateOp}; /// A set of property flags required for an operation. #[non_exhaustive] @@ -62,50 +63,6 @@ impl ValidateOp for super::Module { } } -impl ValidateOp for super::FuncDefn { - fn validity_flags(&self) -> OpValidityFlags { - OpValidityFlags { - allowed_children: OpTag::DataflowChild, - allowed_first_child: OpTag::Input, - allowed_second_child: OpTag::Output, - requires_children: true, - requires_dag: true, - ..Default::default() - } - } - - fn validate_op_children<'a>( - &self, - children: impl DoubleEndedIterator, - ) -> Result<(), ChildrenValidationError> { - // We check type-variables are declared in `validate_subtree`, so here - // we can just assume all type variables are valid regardless of binders. - let FunctionType { input, output, .. } = self.inner_signature(); - validate_io_nodes(&input, &output, "function definition", children) - } -} - -impl ValidateOp for super::DFG { - fn validity_flags(&self) -> OpValidityFlags { - OpValidityFlags { - allowed_children: OpTag::DataflowChild, - allowed_first_child: OpTag::Input, - allowed_second_child: OpTag::Output, - requires_children: true, - requires_dag: true, - ..Default::default() - } - } - - fn validate_op_children<'a>( - &self, - children: impl DoubleEndedIterator, - ) -> Result<(), ChildrenValidationError> { - let sig = self.dataflow_signature().unwrap_or_default(); - validate_io_nodes(&sig.input, &sig.output, "nested graph", children) - } -} - impl ValidateOp for super::Conditional { fn validity_flags(&self) -> OpValidityFlags { OpValidityFlags { @@ -287,46 +244,7 @@ pub struct ChildrenEdgeData { pub target_port: PortOffset, } -impl ValidateOp for BasicBlock { - /// Returns the set of allowed parent operation types. - fn validity_flags(&self) -> OpValidityFlags { - match self { - BasicBlock::DFB { .. } => OpValidityFlags { - allowed_children: OpTag::DataflowChild, - allowed_first_child: OpTag::Input, - allowed_second_child: OpTag::Output, - requires_children: true, - requires_dag: true, - ..Default::default() - }, - // Default flags are valid for non-container operations - BasicBlock::Exit { .. } => Default::default(), - } - } - - /// Validate the ordered list of children. - fn validate_op_children<'a>( - &self, - children: impl DoubleEndedIterator, - ) -> Result<(), ChildrenValidationError> { - match self { - BasicBlock::DFB { - inputs, - tuple_sum_rows: tuple_sum_variants, - other_outputs: outputs, - extension_delta: _, - } => { - let tuple_sum_type = Type::new_tuple_sum(tuple_sum_variants.clone()); - let node_outputs: TypeRow = [&[tuple_sum_type], outputs.as_ref()].concat().into(); - validate_io_nodes(inputs, &node_outputs, "basic block graph", children) - } - // Exit nodes do not have children - BasicBlock::Exit { .. } => Ok(()), - } - } -} - -impl ValidateOp for super::Case { +impl ValidateOp for T { /// Returns the set of allowed parent operation types. fn validity_flags(&self) -> OpValidityFlags { OpValidityFlags { @@ -345,7 +263,7 @@ impl ValidateOp for super::Case { children: impl DoubleEndedIterator, ) -> Result<(), ChildrenValidationError> { let sig = self.inner_signature(); - validate_io_nodes(&sig.input, &sig.output, "Conditional", children) + validate_io_nodes(&sig.input, &sig.output, type_name::(), children) } } @@ -409,14 +327,18 @@ fn validate_io_nodes<'a>( /// Validate an edge between two basic blocks in a CFG sibling graph. fn validate_cfg_edge(edge: ChildrenEdgeData) -> Result<(), EdgeValidationError> { - let [source, target]: [&BasicBlock; 2] = [&edge.source_op, &edge.target_op].map(|op| { - let block_op = op - .as_basic_block() - .expect("CFG sibling graphs can only contain basic block operations."); - block_op - }); - - if source.successor_input(edge.source_port.index()).as_ref() != Some(target.dataflow_input()) { + let source = &edge + .source_op + .as_dfb() + .expect("CFG sibling graphs can only contain basic block operations."); + + let target_input = match &edge.target_op { + OpType::DFB(dfb) => dfb.dataflow_input(), + OpType::Exit(exit) => exit.dataflow_input(), + _ => panic!("CFG sibling graphs can only contain basic block operations."), + }; + + if source.successor_input(edge.source_port.index()).as_ref() != Some(target_input) { return Err(EdgeValidationError::CFGEdgeSignatureMismatch { edge }); } @@ -495,3 +417,4 @@ impl_validate_op!(Call); impl_validate_op!(LoadConstant); impl_validate_op!(CallIndirect); impl_validate_op!(LeafOp); +impl_validate_op!(Exit);