Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: unwrap BasicBlock enum #772

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use super::{
};

use crate::ops::handle::NodeHandle;
use crate::ops::{self, dataflow::DataflowParent, BasicBlock, OpType};
use crate::ops::{self, dataflow::DataflowParent, Exit, OpType, DFB};
use crate::{
extension::{ExtensionRegistry, ExtensionSet},
types::FunctionType,
Expand Down Expand Up @@ -86,7 +86,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
output: TypeRow,
) -> Result<Self, BuildError> {
let n_out_wires = output.len();
let exit_block_type = OpType::BasicBlock(BasicBlock::Exit {
let exit_block_type = OpType::Exit(Exit {
cfg_outputs: output,
});
let exit_node = base
Expand All @@ -102,7 +102,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
})
}

/// Return a builder for a non-entry [`BasicBlock::DFB`] child graph with `inputs`
/// Return a builder for a non-entry [`DFB`] child graph with `inputs`
/// and `outputs` and the variants of the branching TupleSum value
/// specified by `tuple_sum_rows`.
///
Expand Down Expand Up @@ -134,7 +134,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
entry: bool,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect();
let op = OpType::BasicBlock(BasicBlock::DFB {
let op = OpType::DFB(DFB {
inputs: inputs.clone(),
other_outputs: other_outputs.clone(),
tuple_sum_rows: tuple_sum_rows.clone(),
Expand All @@ -153,7 +153,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
BlockBuilder::create(self.hugr_mut(), block_n)
}

/// Return a builder for a non-entry [`BasicBlock::DFB`] child graph with `inputs`
/// Return a builder for a non-entry [`DFB`] child graph with `inputs`
/// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types.
///
/// # Errors
Expand All @@ -172,7 +172,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
)
}

/// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs`
/// Return a builder for the entry [`DFB`] child graph with `inputs`
/// and `outputs` and the variants of the branching TupleSum value
/// specified by `tuple_sum_rows`.
///
Expand All @@ -192,7 +192,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
self.any_block_builder(inputs, tuple_sum_rows, other_outputs, extension_delta, true)
}

/// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs`
/// Return a builder for the entry [`DFB`] child graph with `inputs`
/// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types.
///
/// # Errors
Expand Down Expand Up @@ -229,7 +229,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
}
}

/// Builder for a [`BasicBlock::DFB`] child graph.
/// Builder for a [`DFB`] child graph.
pub type BlockBuilder<B> = DFGWrapper<B, BasicBlockID>;

impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
Expand All @@ -243,7 +243,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
Dataflow::set_outputs(self, [branch_wire].into_iter().chain(outputs))
}
fn create(base: B, block_n: Node) -> Result<Self, BuildError> {
let block_op = base.get_optype(block_n).as_basic_block().unwrap();
let block_op = base.get_optype(block_n).as_dfb().unwrap();
let signature = block_op.inner_signature();
let inp_ex = base
.as_ref()
Expand All @@ -269,7 +269,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
}

impl BlockBuilder<Hugr> {
/// Initialize a [`BasicBlock::DFB`] rooted HUGR builder
/// Initialize a [`DFB`] rooted HUGR builder
pub fn new(
inputs: impl Into<TypeRow>,
input_extensions: impl Into<Option<ExtensionSet>>,
Expand All @@ -280,7 +280,7 @@ impl BlockBuilder<Hugr> {
let inputs = inputs.into();
let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect();
let other_outputs = other_outputs.into();
let op = BasicBlock::DFB {
let op = DFB {
inputs: inputs.clone(),
other_outputs: other_outputs.clone(),
tuple_sum_rows: tuple_sum_rows.clone(),
Expand Down
6 changes: 3 additions & 3 deletions src/extension/infer/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ fn make_block(
let tuple_sum_type = Type::new_tuple_sum(tuple_sum_rows.clone());
let dfb_sig = FunctionType::new(inputs.clone(), vec![tuple_sum_type])
.with_extension_delta(&extension_delta.clone());
let dfb = ops::BasicBlock::DFB {
let dfb = ops::DFB {
inputs,
other_outputs: type_row![],
tuple_sum_rows,
Expand Down Expand Up @@ -497,7 +497,7 @@ fn create_entry_exit(
exit_types: impl Into<TypeRow>,
) -> Result<([Node; 3], Node), Box<dyn Error>> {
let entry_tuple_sum = Type::new_tuple_sum(entry_variants.clone());
let dfb = ops::BasicBlock::DFB {
let dfb = ops::DFB {
inputs: inputs.clone(),
other_outputs: type_row![],
tuple_sum_rows: entry_variants,
Expand All @@ -506,7 +506,7 @@ fn create_entry_exit(

let exit = hugr.add_node_with_parent(
root,
ops::BasicBlock::Exit {
ops::Exit {
cfg_outputs: exit_types.into(),
},
)?;
Expand Down
23 changes: 9 additions & 14 deletions src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ use crate::hugr::rewrite::Rewrite;
use crate::hugr::views::sibling::SiblingMut;
use crate::hugr::{HugrMut, HugrView};
use crate::ops;
use crate::ops::controlflow::BasicBlock;
use crate::ops::dataflow::DataflowOpTrait;
use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle};
use crate::ops::{BasicBlock, OpType};
use crate::ops::{OpType, DFB};
use crate::PortIndex;
use crate::{type_row, Node};

Expand Down Expand Up @@ -114,12 +115,13 @@ impl Rewrite for OutlineCfg {
self.compute_entry_exit_outside_extensions(h)?;
// 1. Compute signature
// These panic()s only happen if the Hugr would not have passed validate()
let OpType::BasicBlock(BasicBlock::DFB { inputs, .. }) = h.get_optype(entry) else {
let OpType::DFB(DFB { inputs, .. }) = h.get_optype(entry) else {
panic!("Entry node is not a basic block")
};
let inputs = inputs.clone();
let outputs = match h.get_optype(outside) {
OpType::BasicBlock(b) => b.dataflow_input().clone(),
OpType::DFB(dfb) => dfb.dataflow_input().clone(),
OpType::Exit(exit) => exit.dataflow_input().clone(),
_ => panic!("External successor not a basic block"),
};
let outer_cfg = h.get_parent(entry).unwrap();
Expand Down Expand Up @@ -265,7 +267,6 @@ mod test {
use crate::hugr::views::sibling::SiblingMut;
use crate::hugr::HugrMut;
use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle};
use crate::ops::{BasicBlock, OpType};
use crate::types::FunctionType;
use crate::{type_row, HugrView, Node};
use cool_asserts::assert_matches;
Expand Down Expand Up @@ -348,12 +349,9 @@ mod test {
h.output_neighbours(tail).take(2).collect::<HashSet<Node>>(),
HashSet::from([exit, new_block])
);
assert_matches!(
h.get_optype(new_block),
OpType::BasicBlock(BasicBlock::DFB { .. })
);
assert!(h.get_optype(new_block).is_dfb());
assert_eq!(h.base_hugr().get_parent(new_cfg), Some(new_block));
assert_matches!(h.base_hugr().get_optype(new_cfg), OpType::CFG(_));
assert!(h.base_hugr().get_optype(new_cfg).is_cfg());
}

#[test]
Expand Down Expand Up @@ -409,12 +407,9 @@ mod test {
.unwrap();
h.update_validate(&PRELUDE_REGISTRY).unwrap();
assert_eq!(new_block, h.children(h.root()).next().unwrap());
assert_matches!(
h.get_optype(new_block),
OpType::BasicBlock(BasicBlock::DFB { .. })
);
assert!(h.get_optype(new_block).is_dfb());
assert_eq!(h.get_parent(new_cfg), Some(new_block));
assert_matches!(h.get_optype(new_cfg), OpType::CFG(_));
assert!(h.get_optype(new_cfg).is_cfg());
for n in other_blocks {
assert_eq!(depth(&h, n), 1);
}
Expand Down
19 changes: 5 additions & 14 deletions src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ mod test {
use crate::ops::custom::{ExternalOp, OpaqueOp};
use crate::ops::dataflow::DataflowOpTrait;
use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
use crate::ops::{self, BasicBlock, Case, LeafOp, OpTag, OpType, DFG};
use crate::ops::{self, Case, LeafOp, OpTag, OpType, DFB, DFG};
use crate::std_extensions::collections;
use crate::types::{FunctionType, Type, TypeArg, TypeRow};
use crate::{type_row, Direction, Hugr, HugrView, OutgoingPort};
Expand Down Expand Up @@ -504,14 +504,8 @@ mod test {
let popp = h.get_parent(pop).unwrap();
let pushp = h.get_parent(push).unwrap();
assert_ne!(popp, pushp); // Two different BBs
assert!(matches!(
h.get_optype(popp),
OpType::BasicBlock(BasicBlock::DFB { .. })
));
assert!(matches!(
h.get_optype(pushp),
OpType::BasicBlock(BasicBlock::DFB { .. })
));
assert!(h.get_optype(popp).is_dfb());
assert!(h.get_optype(pushp).is_dfb());

assert_eq!(h.get_parent(popp).unwrap(), h.get_parent(pushp).unwrap());
}
Expand All @@ -523,7 +517,7 @@ mod test {
}));
let r_bb = replacement.add_node_with_parent(
replacement.root(),
BasicBlock::DFB {
DFB {
inputs: vec![listy.clone()].into(),
tuple_sum_rows: vec![type_row![]],
other_outputs: vec![listy.clone()].into(),
Expand Down Expand Up @@ -596,10 +590,7 @@ mod test {

let grandp = h.get_parent(popp).unwrap();
assert_eq!(grandp, h.get_parent(pushp).unwrap());
assert!(matches!(
h.get_optype(grandp),
OpType::BasicBlock(BasicBlock::DFB { .. })
));
assert!(h.get_optype(grandp).is_dfb());
}

Ok(())
Expand Down
8 changes: 4 additions & 4 deletions src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ fn cfg_children_restrictions() {
let block = b
.add_node_with_parent(
cfg,
ops::BasicBlock::DFB {
ops::DFB {
inputs: type_row![BOOL_T],
tuple_sum_rows: vec![type_row![]],
other_outputs: type_row![BOOL_T],
Expand All @@ -301,7 +301,7 @@ fn cfg_children_restrictions() {
let exit = b
.add_node_with_parent(
cfg,
ops::BasicBlock::Exit {
ops::Exit {
cfg_outputs: type_row![BOOL_T],
},
)
Expand All @@ -315,7 +315,7 @@ fn cfg_children_restrictions() {
let exit2 = b
.add_node_after(
exit,
ops::BasicBlock::Exit {
ops::Exit {
cfg_outputs: type_row![BOOL_T],
},
)
Expand All @@ -330,7 +330,7 @@ fn cfg_children_restrictions() {
// Change the types in the BasicBlock node to work on qubits instead of bits
b.replace_op(
block,
NodeType::new_pure(ops::BasicBlock::DFB {
NodeType::new_pure(ops::DFB {
inputs: type_row![Q],
tuple_sum_rows: vec![type_row![]],
other_outputs: type_row![Q],
Expand Down
4 changes: 2 additions & 2 deletions src/hugr/views/root_checked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ mod test {
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::{HugrError, HugrMut, NodeType};
use crate::ops::handle::{BasicBlockID, CfgID, DataflowParentID, DfgID};
use crate::ops::{BasicBlock, LeafOp, OpTag};
use crate::ops::{LeafOp, OpTag, DFB};
use crate::{ops, type_row, types::FunctionType, Hugr, HugrView};

#[test]
Expand All @@ -94,7 +94,7 @@ mod test {
let mut dfg_v = RootChecked::<&mut Hugr, DfgID>::try_new(&mut h).unwrap();
// That is a HugrMutInternal, so we can try:
let root = dfg_v.root();
let bb = NodeType::new_pure(BasicBlock::DFB {
let bb = NodeType::new_pure(DFB {
inputs: type_row![],
other_outputs: type_row![],
tuple_sum_rows: vec![type_row![]],
Expand Down
9 changes: 6 additions & 3 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use smol_str::SmolStr;
use enum_dispatch::enum_dispatch;

pub use constant::Const;
pub use controlflow::{BasicBlock, Case, Conditional, TailLoop, CFG};
pub use controlflow::{Case, Conditional, Exit, TailLoop, CFG, DFB};
pub use dataflow::{Call, CallIndirect, DataflowParent, Input, LoadConstant, Output, DFG};
pub use leaf::LeafOp;
pub use module::{AliasDecl, AliasDefn, FuncDecl, FuncDefn, Module};
Expand All @@ -48,7 +48,8 @@ pub enum OpType {
LoadConstant,
DFG,
LeafOp,
BasicBlock,
DFB,
Exit,
TailLoop,
CFG,
Conditional,
Expand Down Expand Up @@ -93,7 +94,8 @@ impl_op_ref_try_into!(CallIndirect);
impl_op_ref_try_into!(LoadConstant);
impl_op_ref_try_into!(DFG, dfg);
impl_op_ref_try_into!(LeafOp);
impl_op_ref_try_into!(BasicBlock);
impl_op_ref_try_into!(DFB, dfb);
impl_op_ref_try_into!(Exit);
impl_op_ref_try_into!(TailLoop);
impl_op_ref_try_into!(CFG, cfg);
impl_op_ref_try_into!(Conditional);
Expand Down Expand Up @@ -344,6 +346,7 @@ impl OpParent for TailLoop {}
impl OpParent for CFG {}
impl OpParent for Conditional {}
impl OpParent for FuncDecl {}
impl OpParent for Exit {}

#[enum_dispatch]
/// Methods for Ops to validate themselves and children
Expand Down
Loading