Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add metadata to missing nodes #112

Merged
merged 2 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.277
rev: v0.6.0
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated this to version 0.6.0 to have the same behaviour in pre-commit and during the github-actions-build.

hooks:
- id: ruff

Expand All @@ -21,4 +21,4 @@ repos:
rev: v1.7.5
hooks:
- id: docformatter
args: [--in-place, --black, --wrap-summaries=88, --wrap-descriptions=88]
args: [--in-place, --black, --wrap-summaries=88, --wrap-descriptions=88]
1 change: 1 addition & 0 deletions nir/ir/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class Conv2d(NIRNode):
dilation: Union[int, Tuple[int, int]] # Dilation
groups: int # Groups
bias: np.ndarray # Bias C_out
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
if isinstance(self.padding, str) and self.padding not in ["same", "valid"]:
Expand Down
2 changes: 2 additions & 0 deletions nir/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ class Input(NIRNode):
# Shape of incoming data (overrrides input_type from
# NIRNode to allow for non-keyword (positional) initialization)
input_type: Types
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
self.input_type = parse_shape_argument(self.input_type, "input")
Expand Down Expand Up @@ -479,6 +480,7 @@ class Output(NIRNode):
# Type of incoming data (overrrides input_type from
# NIRNode to allow for non-keyword (positional) initialization)
output_type: Types
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
self.output_type = parse_shape_argument(self.output_type, "output")
Expand Down
2 changes: 2 additions & 0 deletions nir/ir/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Linear(NIRNode):
"""

weight: np.ndarray # Weight term
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
assert len(self.weight.shape) >= 2, "Weight must be at least 2D"
Expand All @@ -69,6 +70,7 @@ class Scale(NIRNode):
"""

scale: np.ndarray # Scaling factor
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
self.input_type = {"input": np.array(self.scale.shape)}
Expand Down
1 change: 1 addition & 0 deletions nir/ir/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class AvgPool2d(NIRNode):
kernel_size: np.ndarray # (Height, Width)
stride: np.ndarray # (Height, width)
padding: np.ndarray # (Height, width)
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
self.input_type = {"input": None}
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,5 @@ find={include = ["nir*"]}
line-length = 100
lint.per-file-ignores = {"docs/conf.py" = ["E402"]}
exclude = ["paper/"]
extend-exclude = ["*.ipynb"]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With version 0.6.0 of ruff, all notebooks are checked by default in ruff.
This lead to errors when checking the spyx example notebook: https://github.com/neuromorphs/NIR/actions/runs/10415679589.

To restore the previous behaviour, I excluded notebooks again.


20 changes: 11 additions & 9 deletions tests/test_readwrite.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import tempfile
import inspect
import sys
import tempfile

import numpy as np

import nir
from tests import mock_affine, mock_conv
from tests import mock_affine, mock_conv, mock_linear

ALL_NODES = []
for name, obj in inspect.getmembers(sys.modules["nir.ir"]):
Expand Down Expand Up @@ -47,7 +47,7 @@ def factory_test_graph(ir: nir.NIRGraph):
assert_equivalence(ir, ir2)


def factory_test_metadata(node):
def factory_test_metadata(ir: nir.NIRGraph):
def compare_dicts(d1, d2):
for k, v in d1.items():
if isinstance(v, np.ndarray):
Expand All @@ -58,12 +58,14 @@ def compare_dicts(d1, d2):
assert v == d2[k]

metadata = {"some": "metadata", "with": 2, "data": np.array([1, 2, 3])}
node.metadata = metadata
compare_dicts(node.metadata, metadata)
for node in ir.nodes.values():
node.metadata = metadata
compare_dicts(node.metadata, metadata)
tmp = tempfile.mktemp()
nir.write(tmp, node)
node2 = nir.read(tmp)
compare_dicts(node2.metadata, metadata)
nir.write(tmp, ir)
ir2 = nir.read(tmp)
for node in ir2.nodes.values():
compare_dicts(node.metadata, metadata)


def test_simple():
Expand Down Expand Up @@ -146,7 +148,7 @@ def test_linear():
tau = np.array([1, 1, 1])
r = np.array([1, 1, 1])
v_leak = np.array([1, 1, 1])
ir = nir.NIRGraph.from_list(mock_affine(2, 2), nir.LI(tau, r, v_leak))
ir = nir.NIRGraph.from_list(mock_linear(2, 2), nir.LI(tau, r, v_leak))
factory_test_graph(ir)
factory_test_metadata(ir)

Expand Down
Loading