'int' object has no attribute 'item' on conv shapes calculation (issue 96) (#101)

* make sure shape is indexed correctly
* Improved and tested convolution shape inference
* Added torch support for _index_tuple
* Updated ruff syntax

---------

Co-authored-by: Jens E. Pedersen <jensegholm@protonmail.com>
stevenabreu7 and Jegp committed Jun 1, 2024
1 parent 2c87511 commit cea08f3
Showing 5 changed files with 58 additions and 9 deletions.
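A minimal sketch of the failure class named in the commit title (not the exact reproduction from issue 96): NumPy arrays and NumPy scalars support indexing and .item(), but a plain Python int does not, so conv parameters supplied as bare ints broke the shape calculation.

    x = 1     # e.g. an integer stride, padding, or input size
    x.item()  # AttributeError: 'int' object has no attribute 'item'

The _index_tuple change below normalizes such scalars before the per-dimension shape loop runs.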
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
@@ -28,9 +28,9 @@ jobs:
     - name: Lint with ruff
       run: |
         # stop the build if there are Python syntax errors or undefined names
-        ruff --output-format=github --select=E9,F63,F7,F82 --target-version=py37 .
+        ruff check . --output-format=github --select=E9,F63,F7,F82 --target-version=py37
         # default set of ruff rules with GitHub Annotations
-        ruff --output-format=github --target-version=py37 --exclude=docs/ --exclude=paper/ .
+        ruff check . --output-format=github --target-version=py37 --exclude=docs/ --exclude=paper/
     - name: Test with pytest
       run: |
         pytest
15 changes: 9 additions & 6 deletions nir/ir/utils.py
@@ -50,7 +50,7 @@ def calculate_conv_output(
     shapes = []
     for i in range(ndim):
         if isinstance(padding, str) and padding == "same":
-            shape = input_shape[i]
+            shape = _index_tuple(input_shape, i)
         else:
             shape = np.floor(
                 (
@@ -87,19 +87,22 @@ def calc_flatten_output(input_shape: Sequence[int], start_dim: int, end_dim: int
     )


-def _index_tuple(
-    tuple: Union[int, Sequence[int]], index: int
-) -> Union[int, np.ndarray]:
+def _index_tuple(tuple: Union[int, Sequence[int]], index: int) -> np.ndarray:
     """If the input is a tuple/array, index it.
     Otherwise, return it as-is.
     """
-    if isinstance(tuple, np.ndarray) or isinstance(tuple, Sequence):
+    if isinstance(tuple, np.ndarray):
         return tuple[index]
+    elif isinstance(tuple, Sequence):
+        return np.array(tuple[index])
     elif isinstance(tuple, (int, np.integer)):
         return np.array([tuple])
     else:
-        raise TypeError(f"tuple must be int or np.ndarray, not {type(tuple)}")
+        try:
+            return tuple[index]
+        except TypeError:
+            raise TypeError(f"tuple must be int or np.ndarray, not {type(tuple)}")


 def ensure_str(a: Union[str, bytes]) -> str:
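A brief usage sketch of the updated helper (illustrative values, mirroring the new tests below; conv_out is not part of the diff, it only restates the standard convolution arithmetic used in the non-"same" branch):

    import numpy as np
    from nir.ir.utils import _index_tuple

    # Scalars are wrapped into an array; sequences and arrays are indexed,
    # so the per-dimension loop can treat every parameter uniformly.
    assert _index_tuple(3, 0) == np.array([3])      # plain int
    assert _index_tuple([2, 3], 1) == 3             # Python sequence
    assert _index_tuple(np.array([2, 3]), 1) == 3   # NumPy array

    def conv_out(size, kernel, stride=1, padding=0, dilation=1):
        # floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride + 1)
        return int(np.floor((size + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1))

    assert conv_out(100, kernel=3) == 98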
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -47,5 +47,5 @@ find={include = ["nir*"]}

 [tool.ruff]
 line-length = 100
-per-file-ignores = {"docs/conf.py" = ["E402"]}
+lint.per-file-ignores = {"docs/conf.py" = ["E402"]}
 exclude = ["paper/"]
22 changes: 22 additions & 0 deletions tests/test_ir.py
@@ -134,6 +134,28 @@ def test_conv2d():
     assert np.allclose(a.output_type["output"], np.array([3, 100, 50]))


+def test_conv2d_same():
+    # Create a NIR Network
+    conv_weights = np.array([[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]])
+    li_tau = np.array([0.9, 0.8])
+    li_r = np.array([1.0, 1.0])
+    li_v_leak = np.array([0.0, 0.0])
+
+    nir_network = nir.NIRGraph.from_list(
+        nir.Conv2d(
+            input_shape=(3, 3),
+            weight=conv_weights,
+            stride=1,
+            padding="same",
+            dilation=1,
+            groups=1,
+            bias=np.array([0.0] * 9),
+        ),
+        nir.LI(li_tau, li_r, li_v_leak),
+    )
+    assert np.allclose(nir_network.nodes["conv2d"].output_type["output"], [1, 3, 3])
+
+
 def test_cuba_lif():
     a = np.random.randn(10, 10)
     lif = nir.CubaLIF(tau_mem=a, tau_syn=a, r=a, v_leak=a, v_threshold=a)
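A rough reading of the new assertion, assuming the leading dimension of the weight is the output-channel count (not stated in the diff itself): with padding="same" the spatial dimensions are preserved, so the expected output type is (out_channels, *input_shape).

    import numpy as np

    weight = np.ones((1, 3, 3))   # same shape as conv_weights in the test above
    input_shape = (3, 3)
    expected = (weight.shape[0], *input_shape)
    assert expected == (1, 3, 3)  # matches the asserted output_type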
24 changes: 24 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,24 @@
+import numpy as np
+import pytest
+
+from nir.ir.utils import _index_tuple
+
+import importlib
+
+_TORCH_SPEC = importlib.util.find_spec("torch")
+
+
+def test_index_tuple():
+    assert _index_tuple(1, 0) == 1
+    assert _index_tuple([1, 2], 0) == 1
+    assert _index_tuple(np.array([1, 2]), 0) == 1
+    assert np.all(
+        np.equal(_index_tuple(np.array([[1, 2], [3, 4]]), 1), np.array([3, 4]))
+    )
+
+
+@pytest.mark.skipif(_TORCH_SPEC is None, reason="requires torch")
+def test_index_tuple_torch():
+    torch = importlib.import_module("torch")
+    assert _index_tuple(torch.tensor([1, 2]), 0) == 1
+    assert _index_tuple(torch.tensor([[1, 2], [3, 4]]), 1).equal(torch.tensor([3, 4]))
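
Why a torch.Tensor works here without torch-specific code in the library: it is neither a NumPy array, a Sequence, nor an int, so _index_tuple reaches the final else branch and plain indexing succeeds. A minimal sketch, assuming torch is installed:

    import torch
    from nir.ir.utils import _index_tuple

    t = torch.tensor([[1, 2], [3, 4]])
    row = _index_tuple(t, 1)   # handled by the try/except fallback
    assert torch.equal(row, torch.tensor([3, 4]))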
