Skip to content

Commit

Permalink
Avoid need to rerun extension inference after outline-cfg rewrite (#534)
Browse files Browse the repository at this point in the history
Mostly by updating the builder to allow specifying ExtensionSets in
places not previously possible: cfg_builder(), BlockBuilder::new, and
BlockBuilder::create.

I could go further: there is `fn block_builder` (and `fn
simple_block_builder` and others), which frustratingly take
`FunctionType`s (that I cannot update to Signature because Sig requires
an ExtensionSet not an Option thereof, so extra argument again)....

I guess requiring to specify `None` as extra param in the builder, may
be an OK price to keep the current behaviour of "run inference later"??

The above then avoids needing to `infer_and_validate` in OutlineCfg
tests after the rewrite, just `validate` works. It'd be good to keep the
other Inline/Outline rewrites and InsertIdentity like this too, but it's
probably not a desirable goal for anything taking a subgraph from the
user (i.e. (Simple/)Replace), where we should certainly allow to infer
requirements across the newly-inserted bit of graph...
  • Loading branch information
acl-cqc committed Sep 14, 2023
1 parent 35a0ef0 commit 79de213
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 9 deletions.
12 changes: 8 additions & 4 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ pub trait Dataflow: Container {
fn cfg_builder(
&mut self,
inputs: impl IntoIterator<Item = (Type, Wire)>,
input_extensions: impl Into<Option<ExtensionSet>>,
output_types: TypeRow,
extension_delta: ExtensionSet,
) -> Result<CFGBuilder<&mut Hugr>, BuildError> {
Expand All @@ -331,10 +332,13 @@ pub trait Dataflow: Container {

let (cfg_node, _) = add_node_with_wires(
self,
NodeType::open_extensions(ops::CFG {
signature: FunctionType::new(inputs.clone(), output_types.clone())
.with_extension_delta(&extension_delta),
}),
NodeType::new(
ops::CFG {
signature: FunctionType::new(inputs.clone(), output_types.clone())
.with_extension_delta(&extension_delta),
},
input_extensions,
),
input_wires,
)?;
CFGBuilder::create(self.hugr_mut(), cfg_node, inputs, output_types)
Expand Down
11 changes: 9 additions & 2 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,12 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
let mut node_outputs = vec![predicate_type];
node_outputs.extend_from_slice(&other_outputs);
let signature = FunctionType::new(inputs, TypeRow::from(node_outputs));
let db = DFGBuilder::create_with_io(base, block_n, signature, None)?;
let inp_ex = base
.as_ref()
.get_nodetype(block_n)
.input_extensions()
.cloned();
let db = DFGBuilder::create_with_io(base, block_n, signature, inp_ex)?;
Ok(BlockBuilder::from_dfg_builder(db))
}

Expand All @@ -287,6 +292,7 @@ impl BlockBuilder<Hugr> {
/// Initialize a [`BasicBlock::DFB`] rooted HUGR builder
pub fn new(
inputs: impl Into<TypeRow>,
input_extensions: impl Into<Option<ExtensionSet>>,
predicate_variants: impl IntoIterator<Item = TypeRow>,
other_outputs: impl Into<TypeRow>,
extension_delta: ExtensionSet,
Expand All @@ -301,7 +307,7 @@ impl BlockBuilder<Hugr> {
extension_delta,
};

let base = Hugr::new(NodeType::open_extensions(op));
let base = Hugr::new(NodeType::new(op, input_extensions));
let root = base.root();
Self::create(base, root, predicate_variants, other_outputs, inputs)
}
Expand Down Expand Up @@ -340,6 +346,7 @@ mod test {
let cfg_id = {
let mut cfg_builder = func_builder.cfg_builder(
vec![(NAT, int)],
None,
type_row![NAT],
ExtensionSet::new(),
)?;
Expand Down
8 changes: 5 additions & 3 deletions src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,10 @@ impl Rewrite for OutlineCfg {

// 2. new_block contains input node, sub-cfg, exit node all connected
let new_block = {
let input_extensions = h.get_nodetype(entry).input_extensions().cloned();
let mut new_block_bldr = BlockBuilder::new(
inputs.clone(),
input_extensions.clone(),
vec![type_row![]],
outputs.clone(),
extension_delta.clone(),
Expand All @@ -126,7 +128,7 @@ impl Rewrite for OutlineCfg {
// N.B. By invoking the cfg_builder, we're forgetting any input
// extensions that may have existed on the original CFG.
let cfg = new_block_bldr
.cfg_builder(wires_in, outputs, extension_delta)
.cfg_builder(wires_in, input_extensions, outputs, extension_delta)
.unwrap();
let cfg_outputs = cfg.finish_sub_container().unwrap().outputs();
let predicate = new_block_bldr
Expand Down Expand Up @@ -291,7 +293,7 @@ mod test {
h.infer_and_validate(&PRELUDE_REGISTRY).unwrap();
let blocks = [head, left, right, merge];
h.apply_rewrite(OutlineCfg::new(blocks)).unwrap();
h.infer_and_validate(&PRELUDE_REGISTRY).unwrap();
h.validate(&PRELUDE_REGISTRY).unwrap();
for n in blocks {
assert_eq!(depth(&h, n), 3);
}
Expand Down Expand Up @@ -326,7 +328,7 @@ mod test {
}
h.apply_rewrite(OutlineCfg::new(blocks_to_move.iter().copied()))
.unwrap();
h.infer_and_validate(&PRELUDE_REGISTRY).unwrap();
h.validate(&PRELUDE_REGISTRY).unwrap();
let new_entry = h.children(h.root()).next().unwrap();
for n in other_blocks {
assert_eq!(depth(&h, n), 1);
Expand Down

0 comments on commit 79de213

Please sign in to comment.