Skip to content

Commit

Permalink
refactor!: unwrap BasicBlock enum (#781)
Browse files Browse the repository at this point in the history
BREAKING_CHANGES: `BasicBlock::{DFB, Exit}` are now standalone structs
`DataflowBlock, ExitBlock`. A level of nesting in serialization is also
removed for these operations.

---------

Co-authored-by: Alan Lawrence <alan.lawrence@quantinuum.com>
  • Loading branch information
ss2165 and acl-cqc committed Jan 5, 2024
1 parent 968c8b0 commit b680662
Show file tree
Hide file tree
Showing 10 changed files with 153 additions and 141 deletions.
21 changes: 11 additions & 10 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::{
BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire,
};

use crate::ops::{self, BasicBlock, OpType};
use crate::ops::{self, DataflowBlock, ExitBlock, OpType};
use crate::{
extension::{ExtensionRegistry, ExtensionSet},
types::FunctionType,
Expand Down Expand Up @@ -86,7 +86,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
output: TypeRow,
) -> Result<Self, BuildError> {
let n_out_wires = output.len();
let exit_block_type = OpType::BasicBlock(BasicBlock::Exit {
let exit_block_type = OpType::ExitBlock(ExitBlock {
cfg_outputs: output,
});
let exit_node = base
Expand All @@ -102,7 +102,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
})
}

/// Return a builder for a non-entry [`BasicBlock::DFB`] child graph with `inputs`
/// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and the variants of the branching TupleSum value
/// specified by `tuple_sum_rows`.
///
Expand Down Expand Up @@ -134,7 +134,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
entry: bool,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect();
let op = OpType::BasicBlock(BasicBlock::DFB {
let op = OpType::DataflowBlock(DataflowBlock {
inputs: inputs.clone(),
other_outputs: other_outputs.clone(),
tuple_sum_rows: tuple_sum_rows.clone(),
Expand All @@ -159,7 +159,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
)
}

/// Return a builder for a non-entry [`BasicBlock::DFB`] child graph with `inputs`
/// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types.
///
/// # Errors
Expand All @@ -178,7 +178,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
)
}

/// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs`
/// Return a builder for the entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and the variants of the branching TupleSum value
/// specified by `tuple_sum_rows`.
///
Expand All @@ -198,7 +198,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
self.any_block_builder(inputs, tuple_sum_rows, other_outputs, extension_delta, true)
}

/// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs`
/// Return a builder for the entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types.
///
/// # Errors
Expand Down Expand Up @@ -235,7 +235,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
}
}

/// Builder for a [`BasicBlock::DFB`] child graph.
/// Builder for a [`DataflowBlock`] child graph.
pub type BlockBuilder<B> = DFGWrapper<B, BasicBlockID>;

impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
Expand All @@ -248,6 +248,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
) -> Result<(), BuildError> {
Dataflow::set_outputs(self, [branch_wire].into_iter().chain(outputs))
}

fn create(
base: B,
block_n: Node,
Expand Down Expand Up @@ -284,7 +285,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
}

impl BlockBuilder<Hugr> {
/// Initialize a [`BasicBlock::DFB`] rooted HUGR builder
/// Initialize a [`DataflowBlock`] rooted HUGR builder
pub fn new(
inputs: impl Into<TypeRow>,
input_extensions: impl Into<Option<ExtensionSet>>,
Expand All @@ -295,7 +296,7 @@ impl BlockBuilder<Hugr> {
let inputs = inputs.into();
let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect();
let other_outputs = other_outputs.into();
let op = BasicBlock::DFB {
let op = DataflowBlock {
inputs: inputs.clone(),
other_outputs: other_outputs.clone(),
tuple_sum_rows: tuple_sum_rows.clone(),
Expand Down
6 changes: 3 additions & 3 deletions src/extension/infer/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ fn make_block(
let tuple_sum_type = Type::new_tuple_sum(tuple_sum_rows.clone());
let dfb_sig = FunctionType::new(inputs.clone(), vec![tuple_sum_type])
.with_extension_delta(&extension_delta.clone());
let dfb = ops::BasicBlock::DFB {
let dfb = ops::DataflowBlock {
inputs,
other_outputs: type_row![],
tuple_sum_rows,
Expand Down Expand Up @@ -496,7 +496,7 @@ fn create_entry_exit(
exit_types: impl Into<TypeRow>,
) -> Result<([Node; 3], Node), Box<dyn Error>> {
let entry_tuple_sum = Type::new_tuple_sum(entry_variants.clone());
let dfb = ops::BasicBlock::DFB {
let dfb = ops::DataflowBlock {
inputs: inputs.clone(),
other_outputs: type_row![],
tuple_sum_rows: entry_variants,
Expand All @@ -505,7 +505,7 @@ fn create_entry_exit(

let exit = hugr.add_node_with_parent(
root,
ops::BasicBlock::Exit {
ops::ExitBlock {
cfg_outputs: exit_types.into(),
},
)?;
Expand Down
23 changes: 9 additions & 14 deletions src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ use crate::hugr::rewrite::Rewrite;
use crate::hugr::views::sibling::SiblingMut;
use crate::hugr::{HugrMut, HugrView};
use crate::ops;
use crate::ops::controlflow::BasicBlock;
use crate::ops::dataflow::DataflowOpTrait;
use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle};
use crate::ops::{BasicBlock, OpType};
use crate::ops::{DataflowBlock, OpType};
use crate::PortIndex;
use crate::{type_row, Node};

Expand Down Expand Up @@ -114,12 +115,13 @@ impl Rewrite for OutlineCfg {
self.compute_entry_exit_outside_extensions(h)?;
// 1. Compute signature
// These panic()s only happen if the Hugr would not have passed validate()
let OpType::BasicBlock(BasicBlock::DFB { inputs, .. }) = h.get_optype(entry) else {
let OpType::DataflowBlock(DataflowBlock { inputs, .. }) = h.get_optype(entry) else {
panic!("Entry node is not a basic block")
};
let inputs = inputs.clone();
let outputs = match h.get_optype(outside) {
OpType::BasicBlock(b) => b.dataflow_input().clone(),
OpType::DataflowBlock(dfb) => dfb.dataflow_input().clone(),
OpType::ExitBlock(exit) => exit.dataflow_input().clone(),
_ => panic!("External successor not a basic block"),
};
let outer_cfg = h.get_parent(entry).unwrap();
Expand Down Expand Up @@ -265,7 +267,6 @@ mod test {
use crate::hugr::views::sibling::SiblingMut;
use crate::hugr::HugrMut;
use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle};
use crate::ops::{BasicBlock, OpType};
use crate::types::FunctionType;
use crate::{type_row, HugrView, Node};
use cool_asserts::assert_matches;
Expand Down Expand Up @@ -348,12 +349,9 @@ mod test {
h.output_neighbours(tail).take(2).collect::<HashSet<Node>>(),
HashSet::from([exit, new_block])
);
assert_matches!(
h.get_optype(new_block),
OpType::BasicBlock(BasicBlock::DFB { .. })
);
assert!(h.get_optype(new_block).is_dataflow_block());
assert_eq!(h.base_hugr().get_parent(new_cfg), Some(new_block));
assert_matches!(h.base_hugr().get_optype(new_cfg), OpType::CFG(_));
assert!(h.base_hugr().get_optype(new_cfg).is_cfg());
}

#[test]
Expand Down Expand Up @@ -409,12 +407,9 @@ mod test {
.unwrap();
h.update_validate(&PRELUDE_REGISTRY).unwrap();
assert_eq!(new_block, h.children(h.root()).next().unwrap());
assert_matches!(
h.get_optype(new_block),
OpType::BasicBlock(BasicBlock::DFB { .. })
);
assert!(h.get_optype(new_block).is_dataflow_block());
assert_eq!(h.get_parent(new_cfg), Some(new_block));
assert_matches!(h.get_optype(new_cfg), OpType::CFG(_));
assert!(h.get_optype(new_cfg).is_cfg());
for n in other_blocks {
assert_eq!(depth(&h, n), 1);
}
Expand Down
19 changes: 5 additions & 14 deletions src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ mod test {
use crate::ops::custom::{ExternalOp, OpaqueOp};
use crate::ops::dataflow::DataflowOpTrait;
use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
use crate::ops::{self, BasicBlock, Case, LeafOp, OpTag, OpType, DFG};
use crate::ops::{self, Case, DataflowBlock, LeafOp, OpTag, OpType, DFG};
use crate::std_extensions::collections;
use crate::types::{FunctionType, Type, TypeArg, TypeRow};
use crate::{type_row, Direction, Hugr, HugrView, OutgoingPort};
Expand Down Expand Up @@ -504,14 +504,8 @@ mod test {
let popp = h.get_parent(pop).unwrap();
let pushp = h.get_parent(push).unwrap();
assert_ne!(popp, pushp); // Two different BBs
assert!(matches!(
h.get_optype(popp),
OpType::BasicBlock(BasicBlock::DFB { .. })
));
assert!(matches!(
h.get_optype(pushp),
OpType::BasicBlock(BasicBlock::DFB { .. })
));
assert!(h.get_optype(popp).is_dataflow_block());
assert!(h.get_optype(pushp).is_dataflow_block());

assert_eq!(h.get_parent(popp).unwrap(), h.get_parent(pushp).unwrap());
}
Expand All @@ -523,7 +517,7 @@ mod test {
}));
let r_bb = replacement.add_node_with_parent(
replacement.root(),
BasicBlock::DFB {
DataflowBlock {
inputs: vec![listy.clone()].into(),
tuple_sum_rows: vec![type_row![]],
other_outputs: vec![listy.clone()].into(),
Expand Down Expand Up @@ -596,10 +590,7 @@ mod test {

let grandp = h.get_parent(popp).unwrap();
assert_eq!(grandp, h.get_parent(pushp).unwrap());
assert!(matches!(
h.get_optype(grandp),
OpType::BasicBlock(BasicBlock::DFB { .. })
));
assert!(h.get_optype(grandp).is_dataflow_block());
}

Ok(())
Expand Down
8 changes: 4 additions & 4 deletions src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ fn cfg_children_restrictions() {
let block = b
.add_node_with_parent(
cfg,
ops::BasicBlock::DFB {
ops::DataflowBlock {
inputs: type_row![BOOL_T],
tuple_sum_rows: vec![type_row![]],
other_outputs: type_row![BOOL_T],
Expand All @@ -301,7 +301,7 @@ fn cfg_children_restrictions() {
let exit = b
.add_node_with_parent(
cfg,
ops::BasicBlock::Exit {
ops::ExitBlock {
cfg_outputs: type_row![BOOL_T],
},
)
Expand All @@ -315,7 +315,7 @@ fn cfg_children_restrictions() {
let exit2 = b
.add_node_after(
exit,
ops::BasicBlock::Exit {
ops::ExitBlock {
cfg_outputs: type_row![BOOL_T],
},
)
Expand All @@ -330,7 +330,7 @@ fn cfg_children_restrictions() {
// Change the types in the BasicBlock node to work on qubits instead of bits
b.replace_op(
block,
NodeType::new_pure(ops::BasicBlock::DFB {
NodeType::new_pure(ops::DataflowBlock {
inputs: type_row![Q],
tuple_sum_rows: vec![type_row![]],
other_outputs: type_row![Q],
Expand Down
4 changes: 2 additions & 2 deletions src/hugr/views/root_checked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ mod test {
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::{HugrError, HugrMut, NodeType};
use crate::ops::handle::{BasicBlockID, CfgID, DataflowParentID, DfgID};
use crate::ops::{BasicBlock, LeafOp, OpTag};
use crate::ops::{DataflowBlock, LeafOp, OpTag};
use crate::{ops, type_row, types::FunctionType, Hugr, HugrView};

#[test]
Expand All @@ -94,7 +94,7 @@ mod test {
let mut dfg_v = RootChecked::<&mut Hugr, DfgID>::try_new(&mut h).unwrap();
// That is a HugrMutInternal, so we can try:
let root = dfg_v.root();
let bb = NodeType::new_pure(BasicBlock::DFB {
let bb = NodeType::new_pure(DataflowBlock {
inputs: type_row![],
other_outputs: type_row![],
tuple_sum_rows: vec![type_row![]],
Expand Down
8 changes: 5 additions & 3 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use smol_str::SmolStr;
use enum_dispatch::enum_dispatch;

pub use constant::Const;
pub use controlflow::{BasicBlock, Case, Conditional, TailLoop, CFG};
pub use controlflow::{Case, Conditional, DataflowBlock, ExitBlock, TailLoop, CFG};
pub use dataflow::{Call, CallIndirect, Input, LoadConstant, Output, DFG};
pub use leaf::LeafOp;
pub use module::{AliasDecl, AliasDefn, FuncDecl, FuncDefn, Module};
Expand All @@ -48,7 +48,8 @@ pub enum OpType {
LoadConstant,
DFG,
LeafOp,
BasicBlock,
DataflowBlock,
ExitBlock,
TailLoop,
CFG,
Conditional,
Expand Down Expand Up @@ -93,7 +94,8 @@ impl_op_ref_try_into!(CallIndirect);
impl_op_ref_try_into!(LoadConstant);
impl_op_ref_try_into!(DFG, dfg);
impl_op_ref_try_into!(LeafOp);
impl_op_ref_try_into!(BasicBlock);
impl_op_ref_try_into!(DataflowBlock);
impl_op_ref_try_into!(ExitBlock);
impl_op_ref_try_into!(TailLoop);
impl_op_ref_try_into!(CFG, cfg);
impl_op_ref_try_into!(Conditional);
Expand Down
Loading

0 comments on commit b680662

Please sign in to comment.