Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Move SimpleReplace::invalidation_set to the Rewrite trait #602

Merged
merged 1 commit into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/hugr/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pub mod insert_identity;
pub mod outline_cfg;
pub mod simple_replace;

use crate::{Hugr, HugrView};
use crate::{Hugr, HugrView, Node};
pub use simple_replace::{SimpleReplacement, SimpleReplacementError};

use super::HugrMut;
Expand All @@ -15,6 +15,11 @@ pub trait Rewrite {
type Error: std::error::Error;
/// The type returned on successful application of the rewrite.
type ApplyResult;
/// The node iterator returned by [`Rewrite::invalidation_set`]
type InvalidationSet<'a>: Iterator<Item = Node> + 'a
where
Self: 'a;

/// If `true`, [self.apply]'s of this rewrite guarantee that they do not mutate the Hugr when they return an Err.
/// If `false`, there is no guarantee; the Hugr should be assumed invalid when Err is returned.
const UNCHANGED_ON_FAILURE: bool;
Expand All @@ -33,6 +38,13 @@ pub trait Rewrite {
/// implementations may begin with `assert!(h.validate())`, with `debug_assert!(h.validate())`
/// being preferred.
fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error>;

/// Returns a set of nodes referenced by the rewrite. Modifying any of these
/// nodes will invalidate it.
///
/// Two `impl Rewrite`s can be composed if their invalidation sets are
/// disjoint.
fn invalidation_set(&self) -> Self::InvalidationSet<'_>;
}

/// Wraps any rewrite into a transaction (i.e. that has no effect upon failure)
Expand All @@ -45,6 +57,9 @@ pub struct Transactional<R> {
impl<R: Rewrite> Rewrite for Transactional<R> {
type Error = R::Error;
type ApplyResult = R::ApplyResult;
type InvalidationSet<'a> = R::InvalidationSet<'a>
where
Self: 'a;
const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> {
Expand Down Expand Up @@ -72,4 +87,9 @@ impl<R: Rewrite> Rewrite for Transactional<R> {
}
r
}

#[inline]
fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
self.underlying.invalidation_set()
}
}
10 changes: 10 additions & 0 deletions src/hugr/rewrite/insert_identity.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Implementation of the `InsertIdentity` operation.

use std::iter;

use crate::hugr::{HugrMut, Node};
use crate::ops::{LeafOp, OpTag, OpTrait};
use crate::types::EdgeKind;
Expand Down Expand Up @@ -51,6 +53,9 @@ impl Rewrite for IdentityInsertion {
type Error = IdentityInsertionError;
/// The inserted node.
type ApplyResult = Node;
type InvalidationSet<'a> = iter::Once<Node>
where
Self: 'a;
const UNCHANGED_ON_FAILURE: bool = true;
fn verify(&self, _h: &impl HugrView) -> Result<(), IdentityInsertionError> {
/*
Expand Down Expand Up @@ -98,6 +103,11 @@ impl Rewrite for IdentityInsertion {
.expect("Should only fail if ports don't exist.");
Ok(new_node)
}

#[inline]
fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
iter::once(self.post_node)
}
}

#[cfg(test)]
Expand Down
11 changes: 10 additions & 1 deletion src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Rewrite for inserting a CFG-node into the hierarchy containing a subsection of an existing CFG
use std::collections::HashSet;
use std::collections::{hash_set, HashSet};
use std::iter;

use itertools::Itertools;
use thiserror::Error;
Expand Down Expand Up @@ -97,6 +98,9 @@ impl Rewrite for OutlineCfg {
///
/// [CFG]: OpType::CFG
type ApplyResult = (Node, Node);
type InvalidationSet<'a> = iter::Copied<hash_set::Iter<'a, Node>>
where
Self: 'a;

const UNCHANGED_ON_FAILURE: bool = true;
fn verify(&self, h: &impl HugrView) -> Result<(), OutlineCfgError> {
Expand Down Expand Up @@ -211,6 +215,11 @@ impl Rewrite for OutlineCfg {

Ok((new_block, cfg_node))
}

#[inline]
fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
self.blocks.iter().copied()
}
}

/// Errors that can occur in expressing an OutlineCfg rewrite.
Expand Down
33 changes: 19 additions & 14 deletions src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! Implementation of the `SimpleReplace` operation.

use std::collections::HashMap;
use std::collections::{hash_map, HashMap};
use std::iter::{self, Copied};
use std::slice;

use itertools::Itertools;

Expand Down Expand Up @@ -56,23 +58,18 @@ impl SimpleReplacement {
pub fn subgraph(&self) -> &SiblingSubgraph {
&self.subgraph
}

/// Returns a set of nodes referenced by the replacement. Modifying any
/// these nodes will invalidate the replacement.
///
/// Two `SimpleReplacement`s can be composed if their affected nodes are
/// disjoint.
#[inline]
pub fn invalidation_set(&self) -> impl Iterator<Item = Node> + '_ {
let subcirc = self.subgraph.nodes().iter().copied();
let out_neighs = self.nu_out.keys().map(|&(n, _)| n);
subcirc.chain(out_neighs)
}
}

type SubgraphNodesIter<'a> = Copied<slice::Iter<'a, Node>>;
type NuOutNodesIter<'a> =
iter::Map<hash_map::Keys<'a, (Node, Port), Port>, fn(&'a (Node, Port)) -> Node>;

impl Rewrite for SimpleReplacement {
type Error = SimpleReplacementError;
type ApplyResult = ();
type InvalidationSet<'a> = iter::Chain<SubgraphNodesIter<'a>, NuOutNodesIter<'a>>
where
Self: 'a;

const UNCHANGED_ON_FAILURE: bool = true;

Expand Down Expand Up @@ -197,6 +194,14 @@ impl Rewrite for SimpleReplacement {
}
Ok(())
}

#[inline]
fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
let subcirc = self.subgraph.nodes().iter().copied();
let get_node: fn(&(Node, Port)) -> Node = |key: &(Node, Port)| key.0;
let out_neighs = self.nu_out.keys().map(get_node);
subcirc.chain(out_neighs)
}
}

/// Error from a [`SimpleReplacement`] operation.
Expand Down Expand Up @@ -227,7 +232,7 @@ pub(in crate::hugr::rewrite) mod test {
use crate::extension::prelude::BOOL_T;
use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::views::{HugrView, SiblingSubgraph};
use crate::hugr::{Hugr, HugrMut, Node};
use crate::hugr::{Hugr, HugrMut, Node, Rewrite};
use crate::ops::OpTag;
use crate::ops::{OpTrait, OpType};
use crate::std_extensions::logic::test::and_op;
Expand Down