Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: NodeType constructors, adding new_auto #635

Merged
merged 7 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ pub(crate) mod test {
/// inference. Using DFGBuilder will default to a root node with an open
/// extension variable
pub(crate) fn closed_dfg_root_hugr(signature: FunctionType) -> Hugr {
let mut hugr = Hugr::new(NodeType::pure(ops::DFG {
let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG {
signature: signature.clone(),
}));
hugr.add_op_with_parent(
Expand Down
4 changes: 2 additions & 2 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ pub trait Dataflow: Container {
op: impl Into<OpType>,
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
self.add_dataflow_node(NodeType::open_extensions(op), input_wires)
self.add_dataflow_node(NodeType::new_auto(op), input_wires)
}

/// Add a dataflow [`NodeType`] to the sibling graph, wiring up the `input_wires` to the
Expand Down Expand Up @@ -628,7 +628,7 @@ fn add_op_with_wires<T: Dataflow + ?Sized>(
optype: impl Into<OpType>,
inputs: Vec<Wire>,
) -> Result<(Node, usize), BuildError> {
add_node_with_wires(data_builder, NodeType::open_extensions(optype), inputs)
add_node_with_wires(data_builder, NodeType::new_auto(optype), inputs)
}

fn add_node_with_wires<T: Dataflow + ?Sized>(
Expand Down
2 changes: 1 addition & 1 deletion src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl CFGBuilder<Hugr> {
signature: signature.clone(),
};

let base = Hugr::new(NodeType::open_extensions(cfg_op));
let base = Hugr::new(NodeType::new_open(cfg_op));
let cfg_node = base.root();
CFGBuilder::create(base, cfg_node, signature.input, signature.output)
}
Expand Down
4 changes: 2 additions & 2 deletions src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ impl ConditionalBuilder<Hugr> {
extension_delta,
};
// TODO: Allow input extensions to be specified
let base = Hugr::new(NodeType::open_extensions(op));
let base = Hugr::new(NodeType::new_open(op));
let conditional_node = base.root();

Ok(ConditionalBuilder {
Expand All @@ -194,7 +194,7 @@ impl CaseBuilder<Hugr> {
let op = ops::Case {
signature: signature.clone(),
};
let base = Hugr::new(NodeType::open_extensions(op));
let base = Hugr::new(NodeType::new_open(op));
let root = base.root();
let dfg_builder = DFGBuilder::create_with_io(base, root, signature, None)?;

Expand Down
2 changes: 1 addition & 1 deletion src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl DFGBuilder<Hugr> {
let dfg_op = ops::DFG {
signature: signature.clone(),
};
let base = Hugr::new(NodeType::open_extensions(dfg_op));
let base = Hugr::new(NodeType::new_open(dfg_op));
let root = base.root();
DFGBuilder::create_with_io(base, root, signature, None)
}
Expand Down
2 changes: 1 addition & 1 deletion src/builder/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
};
self.hugr_mut().replace_op(
f_node,
NodeType::pure(ops::FuncDefn {
NodeType::new_pure(ops::FuncDefn {
name,
signature: signature.clone(),
}),
Expand Down
2 changes: 1 addition & 1 deletion src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl TailLoopBuilder<Hugr> {
rest: inputs_outputs.into(),
};
// TODO: Allow input extensions to be specified
let base = Hugr::new(NodeType::open_extensions(tail_loop.clone()));
let base = Hugr::new(NodeType::new_open(tail_loop.clone()));
let root = base.root();
Self::create_with_io(base, root, &tail_loop)
}
Expand Down
24 changes: 9 additions & 15 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,6 @@ impl UnificationContext {
m_output,
node_type.op_signature().extension_reqs,
);
if matches!(
node_type.tag(),
OpTag::Alias | OpTag::Function | OpTag::FuncDefn
) {
self.add_solution(m_input, ExtensionSet::new());
}
}
// We have a solution for everything!
Some(sig) => {
Expand Down Expand Up @@ -723,7 +717,7 @@ mod test {
signature: main_sig,
};

let root_node = NodeType::open_extensions(op);
let root_node = NodeType::new_open(op);
let mut hugr = Hugr::new(root_node);

let input = ops::Input::new(type_row![NAT, NAT]);
Expand Down Expand Up @@ -833,21 +827,21 @@ mod test {
// This generates a solution that causes validation to fail
// because of a missing lift node
fn missing_lift_node() -> Result<(), Box<dyn Error>> {
let mut hugr = Hugr::new(NodeType::pure(ops::DFG {
let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A)),
}));

let input = hugr.add_node_with_parent(
hugr.root(),
NodeType::pure(ops::Input {
NodeType::new_pure(ops::Input {
types: type_row![NAT],
}),
)?;

let output = hugr.add_node_with_parent(
hugr.root(),
NodeType::pure(ops::Output {
NodeType::new_pure(ops::Output {
types: type_row![NAT],
}),
)?;
Expand Down Expand Up @@ -1049,7 +1043,7 @@ mod test {
extension_delta: rs.clone(),
};

let mut hugr = Hugr::new(NodeType::pure(op));
let mut hugr = Hugr::new(NodeType::new_pure(op));
let conditional_node = hugr.root();

let case_op = ops::Case {
Expand Down Expand Up @@ -1084,7 +1078,7 @@ mod test {
fn extension_adding_sequence() -> Result<(), Box<dyn Error>> {
let df_sig = FunctionType::new(type_row![NAT], type_row![NAT]);

let mut hugr = Hugr::new(NodeType::open_extensions(ops::DFG {
let mut hugr = Hugr::new(NodeType::new_open(ops::DFG {
signature: df_sig
.clone()
.with_extension_delta(&ExtensionSet::from_iter([A, B])),
Expand Down Expand Up @@ -1255,7 +1249,7 @@ mod test {
let b = ExtensionSet::singleton(&B);
let c = ExtensionSet::singleton(&C);

let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG {
let mut hugr = Hugr::new(NodeType::new_open(ops::CFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&abc),
}));

Expand Down Expand Up @@ -1353,7 +1347,7 @@ mod test {
/// +--------------------+
#[test]
fn multi_entry() -> Result<(), Box<dyn Error>> {
let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG {
let mut hugr = Hugr::new(NodeType::new_open(ops::CFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT]), // maybe add extensions?
}));
let cfg = hugr.root();
Expand Down Expand Up @@ -1436,7 +1430,7 @@ mod test {
) -> Result<Hugr, Box<dyn Error>> {
let hugr_delta = entry_ext.clone().union(&bb1_ext).union(&bb2_ext);

let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG {
let mut hugr = Hugr::new(NodeType::new_open(ops::CFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&hugr_delta),
}));
Expand Down
21 changes: 15 additions & 6 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl NodeType {
}

/// Instantiate an OpType with no input extensions
pub fn pure(op: impl Into<OpType>) -> Self {
pub fn new_pure(op: impl Into<OpType>) -> Self {
NodeType {
op: op.into(),
input_extensions: Some(ExtensionSet::new()),
Expand All @@ -91,13 +91,24 @@ impl NodeType {

/// Instantiate an OpType with an unknown set of input extensions
/// (to be inferred later)
pub fn open_extensions(op: impl Into<OpType>) -> Self {
pub fn new_open(op: impl Into<OpType>) -> Self {
NodeType {
op: op.into(),
input_extensions: None,
}
}

/// Instantiate an [OpType] with the default set of input extensions
/// for that OpType.
pub fn new_auto(op: impl Into<OpType>) -> Self {
let op = op.into();
if OpTag::ModuleOp.is_superset(op.tag()) {
Self::new_pure(op)
} else {
Self::new_open(op)
}
}

/// Use the input extensions to calculate the concrete signature of the node
pub fn signature(&self) -> Option<Signature> {
self.input_extensions
Expand All @@ -119,9 +130,7 @@ impl NodeType {
pub fn input_extensions(&self) -> Option<&ExtensionSet> {
self.input_extensions.as_ref()
}
}

impl NodeType {
/// Gets the underlying [OpType] i.e. without any [input_extensions]
///
/// [input_extensions]: NodeType::input_extensions
Expand Down Expand Up @@ -153,7 +162,7 @@ impl OpType {

impl Default for Hugr {
fn default() -> Self {
Self::new(NodeType::pure(crate::ops::Module))
Self::new(NodeType::new_pure(crate::ops::Module))
}
}

Expand Down Expand Up @@ -239,7 +248,7 @@ impl Hugr {

/// Add a node to the graph, with the default conversion from OpType to NodeType
pub(crate) fn add_op(&mut self, op: impl Into<OpType>) -> Node {
self.add_node(NodeType::open_extensions(op))
self.add_node(NodeType::new_auto(op))
}

/// Add a node to the graph.
Expand Down
6 changes: 3 additions & 3 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub trait HugrMut: HugrMutInternals {
parent: Node,
op: impl Into<OpType>,
) -> Result<Node, HugrError> {
self.add_node_with_parent(parent, NodeType::open_extensions(op))
self.add_node_with_parent(parent, NodeType::new_auto(op))
}

/// Add a node to the graph with a parent in the hierarchy.
Expand Down Expand Up @@ -217,7 +217,7 @@ impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
}

fn add_op_before(&mut self, sibling: Node, op: impl Into<OpType>) -> Result<Node, HugrError> {
self.add_node_before(sibling, NodeType::open_extensions(op))
self.add_node_before(sibling, NodeType::new_auto(op))
}

fn add_node_before(&mut self, sibling: Node, nodetype: NodeType) -> Result<Node, HugrError> {
Expand Down Expand Up @@ -620,7 +620,7 @@ mod test {

{
let f_in = hugr
.add_node_with_parent(f, NodeType::pure(ops::Input::new(type_row![NAT])))
.add_node_with_parent(f, NodeType::new_pure(ops::Input::new(type_row![NAT])))
.unwrap();
let f_out = hugr
.add_op_with_parent(f, ops::Output::new(type_row![NAT, NAT]))
Expand Down
9 changes: 3 additions & 6 deletions src/hugr/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,7 @@ impl TryFrom<SerHugrV0> for Hugr {
for node_ser in nodes {
hugr.add_node_with_parent(
node_ser.parent,
match node_ser.input_extensions {
None => NodeType::open_extensions(node_ser.op),
Some(rs) => NodeType::new(node_ser.op, rs),
},
NodeType::new(node_ser.op, node_ser.input_extensions),
)?;
}

Expand Down Expand Up @@ -332,11 +329,11 @@ pub mod test {
let mut h = Hierarchy::new();
let mut op_types = UnmanagedDenseMap::new();

op_types[root] = NodeType::open_extensions(gen_optype(&g, root));
op_types[root] = NodeType::new_open(gen_optype(&g, root));

for n in [a, b, c] {
h.push_child(n, root).unwrap();
op_types[n] = NodeType::pure(gen_optype(&g, n));
op_types[n] = NodeType::new_pure(gen_optype(&g, n));
}

let hg = Hugr {
Expand Down
Loading