Skip to content

Commit

Permalink
feat: Move SimpleReplace::invalidation_set to the Rewrite trait (#…
Browse files Browse the repository at this point in the history
…602)

Closes #601
  • Loading branch information
aborgna-q committed Oct 13, 2023
1 parent f570b97 commit 2c17ad6
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 16 deletions.
22 changes: 21 additions & 1 deletion src/hugr/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pub mod insert_identity;
pub mod outline_cfg;
pub mod simple_replace;

use crate::{Hugr, HugrView};
use crate::{Hugr, HugrView, Node};
pub use simple_replace::{SimpleReplacement, SimpleReplacementError};

use super::HugrMut;
Expand All @@ -15,6 +15,11 @@ pub trait Rewrite {
type Error: std::error::Error;
/// The type returned on successful application of the rewrite.
type ApplyResult;
/// The node iterator returned by [`Rewrite::invalidation_set`]
type InvalidationSet<'a>: Iterator<Item = Node> + 'a
where
Self: 'a;

/// If `true`, [self.apply]'s of this rewrite guarantee that they do not mutate the Hugr when they return an Err.
/// If `false`, there is no guarantee; the Hugr should be assumed invalid when Err is returned.
const UNCHANGED_ON_FAILURE: bool;
Expand All @@ -33,6 +38,13 @@ pub trait Rewrite {
/// implementations may begin with `assert!(h.validate())`, with `debug_assert!(h.validate())`
/// being preferred.
fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error>;

/// Returns a set of nodes referenced by the rewrite. Modifying any of these
/// nodes will invalidate it.
///
/// Two `impl Rewrite`s can be composed if their invalidation sets are
/// disjoint.
fn invalidation_set(&self) -> Self::InvalidationSet<'_>;
}

/// Wraps any rewrite into a transaction (i.e. that has no effect upon failure)
Expand All @@ -45,6 +57,9 @@ pub struct Transactional<R> {
impl<R: Rewrite> Rewrite for Transactional<R> {
type Error = R::Error;
type ApplyResult = R::ApplyResult;
type InvalidationSet<'a> = R::InvalidationSet<'a>
where
Self: 'a;
const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> {
Expand Down Expand Up @@ -72,4 +87,9 @@ impl<R: Rewrite> Rewrite for Transactional<R> {
}
r
}

#[inline]
fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
self.underlying.invalidation_set()
}
}
10 changes: 10 additions & 0 deletions src/hugr/rewrite/insert_identity.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Implementation of the `InsertIdentity` operation.

use std::iter;

use crate::hugr::{HugrMut, Node};
use crate::ops::{LeafOp, OpTag, OpTrait};
use crate::types::EdgeKind;
Expand Down Expand Up @@ -51,6 +53,9 @@ impl Rewrite for IdentityInsertion {
type Error = IdentityInsertionError;
/// The inserted node.
type ApplyResult = Node;
type InvalidationSet<'a> = iter::Once<Node>
where
Self: 'a;
const UNCHANGED_ON_FAILURE: bool = true;
fn verify(&self, _h: &impl HugrView) -> Result<(), IdentityInsertionError> {
/*
Expand Down Expand Up @@ -98,6 +103,11 @@ impl Rewrite for IdentityInsertion {
.expect("Should only fail if ports don't exist.");
Ok(new_node)
}

#[inline]
fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
iter::once(self.post_node)
}
}

#[cfg(test)]
Expand Down
11 changes: 10 additions & 1 deletion src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Rewrite for inserting a CFG-node into the hierarchy containing a subsection of an existing CFG
use std::collections::HashSet;
use std::collections::{hash_set, HashSet};
use std::iter;

use itertools::Itertools;
use thiserror::Error;
Expand Down Expand Up @@ -97,6 +98,9 @@ impl Rewrite for OutlineCfg {
///
/// [CFG]: OpType::CFG
type ApplyResult = (Node, Node);
type InvalidationSet<'a> = iter::Copied<hash_set::Iter<'a, Node>>
where
Self: 'a;

const UNCHANGED_ON_FAILURE: bool = true;
fn verify(&self, h: &impl HugrView) -> Result<(), OutlineCfgError> {
Expand Down Expand Up @@ -211,6 +215,11 @@ impl Rewrite for OutlineCfg {

Ok((new_block, cfg_node))
}

#[inline]
fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
self.blocks.iter().copied()
}
}

/// Errors that can occur in expressing an OutlineCfg rewrite.
Expand Down
33 changes: 19 additions & 14 deletions src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! Implementation of the `SimpleReplace` operation.

use std::collections::HashMap;
use std::collections::{hash_map, HashMap};
use std::iter::{self, Copied};
use std::slice;

use itertools::Itertools;

Expand Down Expand Up @@ -56,23 +58,18 @@ impl SimpleReplacement {
pub fn subgraph(&self) -> &SiblingSubgraph {
&self.subgraph
}

/// Returns a set of nodes referenced by the replacement. Modifying any
/// these nodes will invalidate the replacement.
///
/// Two `SimpleReplacement`s can be composed if their affected nodes are
/// disjoint.
#[inline]
pub fn invalidation_set(&self) -> impl Iterator<Item = Node> + '_ {
let subcirc = self.subgraph.nodes().iter().copied();
let out_neighs = self.nu_out.keys().map(|&(n, _)| n);
subcirc.chain(out_neighs)
}
}

type SubgraphNodesIter<'a> = Copied<slice::Iter<'a, Node>>;
type NuOutNodesIter<'a> =
iter::Map<hash_map::Keys<'a, (Node, Port), Port>, fn(&'a (Node, Port)) -> Node>;

impl Rewrite for SimpleReplacement {
type Error = SimpleReplacementError;
type ApplyResult = ();
type InvalidationSet<'a> = iter::Chain<SubgraphNodesIter<'a>, NuOutNodesIter<'a>>
where
Self: 'a;

const UNCHANGED_ON_FAILURE: bool = true;

Expand Down Expand Up @@ -197,6 +194,14 @@ impl Rewrite for SimpleReplacement {
}
Ok(())
}

#[inline]
fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
let subcirc = self.subgraph.nodes().iter().copied();
let get_node: fn(&(Node, Port)) -> Node = |key: &(Node, Port)| key.0;
let out_neighs = self.nu_out.keys().map(get_node);
subcirc.chain(out_neighs)
}
}

/// Error from a [`SimpleReplacement`] operation.
Expand Down Expand Up @@ -227,7 +232,7 @@ pub(in crate::hugr::rewrite) mod test {
use crate::extension::prelude::BOOL_T;
use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::views::{HugrView, SiblingSubgraph};
use crate::hugr::{Hugr, HugrMut, Node};
use crate::hugr::{Hugr, HugrMut, Node, Rewrite};
use crate::ops::OpTag;
use crate::ops::{OpTrait, OpType};
use crate::std_extensions::logic::test::and_op;
Expand Down

0 comments on commit 2c17ad6

Please sign in to comment.