Skip to content

Commit

Permalink
[Refactor] Tidy some CFG structuring code, require RootHandle=CfgID (#…
Browse files Browse the repository at this point in the history
…597)

* IdentityCfgMap and HalfNodeView require RootTagged w/ CfgID rather
than just HugrView
* Derive `Clone` for RootHandle
* implement a `fn borrow` (returning `RootHandle<&Hugr>`) for
`RootHandle<&mut Hugr`
* Avoid some redundant checks and unwraps, remove a misleading `pub`
(the structure it's on isn't pub)
  • Loading branch information
acl-cqc committed Oct 13, 2023
1 parent 9254ac7 commit f570b97
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 21 deletions.
10 changes: 6 additions & 4 deletions src/algorithm/half_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use std::hash::Hash;

use super::nest_cfgs::CfgNodeMap;

use crate::hugr::views::HugrView;
use crate::hugr::RootTagged;

use crate::ops::handle::CfgID;
use crate::ops::{OpTag, OpTrait};

use crate::{Direction, Node};
Expand All @@ -30,7 +31,7 @@ struct HalfNodeView<H> {
exit: Node,
}

impl<H: HugrView> HalfNodeView<H> {
impl<H: RootTagged<RootHandle = CfgID>> HalfNodeView<H> {
#[allow(unused)]
pub(crate) fn new(h: H) -> Self {
let (entry, exit) = {
Expand Down Expand Up @@ -62,7 +63,7 @@ impl<H: HugrView> HalfNodeView<H> {
}
}

impl<H: HugrView> CfgNodeMap<HalfNode> for HalfNodeView<H> {
impl<H: RootTagged<RootHandle = CfgID>> CfgNodeMap<HalfNode> for HalfNodeView<H> {
type Iterator<'c> = <Vec<HalfNode> as IntoIterator>::IntoIter where Self: 'c;
fn entry_node(&self) -> HalfNode {
HalfNode::N(self.entry)
Expand Down Expand Up @@ -97,6 +98,7 @@ mod test {
use super::super::nest_cfgs::{test::*, EdgeClassifier};
use super::{HalfNode, HalfNodeView};
use crate::builder::BuildError;
use crate::hugr::views::RootChecked;
use crate::ops::handle::NodeHandle;

use itertools::Itertools;
Expand All @@ -116,7 +118,7 @@ mod test {
// \---<---<---<---<---<---<---<---<---<---/
// Allowing to identify two nested regions (and fixing the problem with an IdentityCfgMap on the same example)

let v = HalfNodeView::new(&h);
let v = HalfNodeView::new(RootChecked::try_new(&h).unwrap());

let edge_classes = EdgeClassifier::get_edge_classes(&v);
let HalfNodeView { h: _, entry, exit } = v;
Expand Down
33 changes: 16 additions & 17 deletions src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use thiserror::Error;
use crate::hugr::rewrite::outline_cfg::OutlineCfg;
use crate::hugr::views::sibling::SiblingMut;
use crate::hugr::views::{HierarchyView, HugrView, SiblingGraph};
use crate::hugr::{HugrMut, Rewrite};
use crate::hugr::{HugrMut, Rewrite, RootTagged};
use crate::ops::handle::{BasicBlockID, CfgID};
use crate::ops::OpTag;
use crate::ops::OpTrait;
Expand Down Expand Up @@ -158,9 +158,7 @@ pub fn transform_cfg_to_nested<T: Copy + Eq + Hash + std::fmt::Debug>(
pub fn transform_all_cfgs(h: &mut Hugr) {
let mut node_stack = Vec::from([h.root()]);
while let Some(n) = node_stack.pop() {
if h.get_optype(n).tag() == OpTag::Cfg {
// We've checked the optype so this should be fine
let s = SiblingMut::<CfgID>::try_new(h, n).unwrap();
if let Ok(s) = SiblingMut::<CfgID>::try_new(h, n) {
transform_cfg_to_nested(&mut IdentityCfgMap::new(s));
}
node_stack.extend(h.children(n))
Expand Down Expand Up @@ -224,7 +222,7 @@ pub struct IdentityCfgMap<H> {
entry: Node,
exit: Node,
}
impl<H: HugrView> IdentityCfgMap<H> {
impl<H: RootTagged<RootHandle = CfgID>> IdentityCfgMap<H> {
/// Creates an [IdentityCfgMap] for the specified CFG
pub fn new(h: H) -> Self {
// Panic if malformed enough not to have two children
Expand Down Expand Up @@ -358,7 +356,7 @@ struct UndirectedDFSTree<T> {
}

impl<T: Copy + Clone + PartialEq + Eq + Hash> UndirectedDFSTree<T> {
pub fn new(cfg: &impl CfgNodeMap<T>) -> Self {
fn new(cfg: &impl CfgNodeMap<T>) -> Self {
//1. Traverse backwards-only from exit building bitset of reachable nodes
let mut reachable = HashSet::new();
{
Expand Down Expand Up @@ -579,6 +577,7 @@ pub(crate) mod test {
use crate::extension::PRELUDE_REGISTRY;
use crate::extension::{prelude::USIZE_T, ExtensionSet};

use crate::hugr::views::RootChecked;
use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
use crate::ops::Const;
use crate::types::{FunctionType, Type};
Expand Down Expand Up @@ -624,13 +623,10 @@ pub(crate) mod test {
cfg_builder.branch(&tail, 0, &exit)?;

let mut h = cfg_builder.finish_prelude_hugr()?;

let rc = RootChecked::<_, CfgID>::try_new(&mut h).unwrap();
let (entry, exit) = (entry.node(), exit.node());
let (split, merge, head, tail) = (split.node(), merge.node(), head.node(), tail.node());
// There's no need to use a view of a region here but we do so just to check
// that we *can* (as we'll need to for "real" module Hugr's)
let v: SiblingGraph = SiblingGraph::try_new(&h, h.root()).unwrap();
let edge_classes = EdgeClassifier::get_edge_classes(&IdentityCfgMap::new(v));
let edge_classes = EdgeClassifier::get_edge_classes(&IdentityCfgMap::new(rc.borrow()));
let [&left, &right] = edge_classes
.keys()
.filter(|(s, _)| *s == split)
Expand All @@ -651,7 +647,7 @@ pub(crate) mod test {
sorted([(entry, split), (merge, head), (tail, exit)]), // Two regions, conditional and then loop.
])
);
transform_cfg_to_nested(&mut IdentityCfgMap::new(&mut h));
transform_cfg_to_nested(&mut IdentityCfgMap::new(rc));
h.validate(&PRELUDE_REGISTRY).unwrap();
assert_eq!(1, depth(&h, entry));
assert_eq!(1, depth(&h, exit));
Expand Down Expand Up @@ -690,7 +686,8 @@ pub(crate) mod test {
.try_into()
.unwrap();

let edge_classes = EdgeClassifier::get_edge_classes(&IdentityCfgMap::new(&h));
let v = IdentityCfgMap::new(RootChecked::try_new(&h).unwrap());
let edge_classes = EdgeClassifier::get_edge_classes(&v);
let [&left, &right] = edge_classes
.keys()
.filter(|(s, _)| *s == entry)
Expand Down Expand Up @@ -728,7 +725,9 @@ pub(crate) mod test {
// merge is unique predecessor of tail
let merge = h.input_neighbours(tail).exactly_one().unwrap();

let v = IdentityCfgMap::new(&h);
// There's no need to use a view of a region here but we do so just to check
// that we *can* (as we'll need to for "real" module Hugr's)
let v = IdentityCfgMap::new(SiblingGraph::try_new(&h, h.root()).unwrap());
let edge_classes = EdgeClassifier::get_edge_classes(&v);
let IdentityCfgMap { h: _, entry, exit } = v;
let [&left, &right] = edge_classes
Expand All @@ -751,8 +750,8 @@ pub(crate) mod test {
])
);

// We could operate on the (&mut) Hugr directly here, but check that the transformation
// works on a SiblingMut (i.e. which only allows direct mutation at the top level)
// Again, there's no need for a view of a region here, but check that the
// transformation still works when we can only directly mutate the top level
let root = h.root();
let m = SiblingMut::<CfgID>::try_new(&mut h, root).unwrap();
transform_cfg_to_nested(&mut IdentityCfgMap::new(m));
Expand All @@ -779,7 +778,7 @@ pub(crate) mod test {
// Here we would like an indication that we can make two nested regions,
// but there is no edge to act as entry to a region containing just the conditional :-(.

let v = IdentityCfgMap::new(&h);
let v = IdentityCfgMap::new(RootChecked::try_new(&h).unwrap());
let edge_classes = EdgeClassifier::get_edge_classes(&v);
let IdentityCfgMap { h: _, entry, exit } = v;
// merge is unique predecessor of tail
Expand Down
8 changes: 8 additions & 0 deletions src/hugr/views/root_checked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use super::{check_tag, RootTagged};

/// A view of the whole Hugr.
/// (Just provides static checking of the type of the root node)
#[derive(Clone)]
pub struct RootChecked<H, Root = Node>(H, PhantomData<Root>);

impl<H: RootTagged + AsRef<Hugr>, Root: NodeHandle> RootChecked<H, Root> {
Expand Down Expand Up @@ -37,6 +38,13 @@ impl<Root> RootChecked<Hugr, Root> {
}
}

impl<Root> RootChecked<&mut Hugr, Root> {
/// Allows immutably borrowing the underlying mutable reference
pub fn borrow(&self) -> RootChecked<&Hugr, Root> {
RootChecked(&*self.0, PhantomData)
}
}

impl<H: AsRef<Hugr>, Root: NodeHandle> RootTagged for RootChecked<H, Root> {
type RootHandle = Root;
}
Expand Down

0 comments on commit f570b97

Please sign in to comment.