From 2ce271e3a8ff977a6b45b01c0356ab56729c88b5 Mon Sep 17 00:00:00 2001 From: Matthias Jobst Date: Tue, 26 Mar 2024 11:31:36 +0100 Subject: [PATCH] fixed NIRNode import issue #85 --- nir/ir/__init__.py | 2 ++ tests/test_ir.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/nir/ir/__init__.py b/nir/ir/__init__.py index 5606db5..ce0c09d 100644 --- a/nir/ir/__init__.py +++ b/nir/ir/__init__.py @@ -81,6 +81,8 @@ def dict2NIRNode(data_dict: Dict[str, Any]) -> NIRNode: "IF", "LI", "LIF", + # node + "NIRNode", # pooling "AvgPool2d", "SumPool2d", diff --git a/tests/test_ir.py b/tests/test_ir.py index d73813b..9b6ade4 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -9,6 +9,10 @@ def test_has_version(): assert hasattr(nir, "__version__") +def test_has_NIRNode(): + assert hasattr(nir, "NIRNode") + + def test_eq(): a = nir.Input(np.array([2, 3])) a2 = nir.Input(np.array([2, 3]))