From 39e00a2d1a4148679f47279924d87a9a90898e83 Mon Sep 17 00:00:00 2001 From: Bernhard Vogginger Date: Thu, 21 Mar 2024 16:20:58 +0100 Subject: [PATCH] Add AvgPool2d - fixes #83 - also applied code formatting to tests/test_ir.py --- docs/source/primitives.md | 3 +- nir/ir/__init__.py | 4 +- nir/ir/graph.py | 15 +- nir/ir/pooling.py | 13 ++ nir/serialization.py | 6 + tests/test_ir.py | 343 +++++++++++++++++++++----------------- tests/test_readwrite.py | 15 ++ 7 files changed, 245 insertions(+), 154 deletions(-) diff --git a/docs/source/primitives.md b/docs/source/primitives.md index 708c452..49e5391 100644 --- a/docs/source/primitives.md +++ b/docs/source/primitives.md @@ -18,6 +18,7 @@ NIR defines 16 fundamental primitives listed in the table below, which backends | **Leaky integrate-fire (LIF)** | $\tau, \text{R}, v_\text{leak}, v_\text{thr}$ | **LI**; **Threshold** | $\begin{cases} v-v_\text{thr} & \text{Spike} \\ v & \text{else} \end{cases}$ | | **Scale** | $s$ | $s I$ | - | | **SumPooling** | $p$ | $\sum_{j} x_j$ | | +| **AvgPooling** | $p$ | **SumPooling**; **Scale** | - | | **Threshold** | $\theta_\text{thr}$ | $H(I - \theta_\text{thr})$ | - | Each primitive is defined by their own dynamical equation, specified in the [API docs](https://nnir.readthedocs.io/en/latest/modindex.html). @@ -33,4 +34,4 @@ $$ $$ ## Format -The intermediate represenation can be stored as hdf5 file, which benefits from compression. \ No newline at end of file +The intermediate represenation can be stored as hdf5 file, which benefits from compression. diff --git a/nir/ir/__init__.py b/nir/ir/__init__.py index 41362e7..5606db5 100644 --- a/nir/ir/__init__.py +++ b/nir/ir/__init__.py @@ -6,7 +6,7 @@ from .graph import Input, NIRGraph, Output from .linear import Affine, Linear, Scale from .neuron import IF, LI, LIF, CubaLIF, I -from .pooling import SumPool2d +from .pooling import AvgPool2d, SumPool2d from .surrogate_gradient import Threshold from .typing import NIRNode @@ -34,6 +34,7 @@ "LI", "LIF", # pooling + "AvgPool2d", "SumPool2d", # surrogate_gradient "Threshold", @@ -81,6 +82,7 @@ def dict2NIRNode(data_dict: Dict[str, Any]) -> NIRNode: "LI", "LIF", # pooling + "AvgPool2d", "SumPool2d", # surrogate_gradient "Threshold", diff --git a/nir/ir/graph.py b/nir/ir/graph.py index 653441f..5e01701 100644 --- a/nir/ir/graph.py +++ b/nir/ir/graph.py @@ -7,7 +7,7 @@ from .conv import Conv1d, Conv2d from .flatten import Flatten from .node import NIRNode -from .pooling import SumPool2d +from .pooling import AvgPool2d, SumPool2d from .typing import Edges, Nodes, Types from .utils import ( calc_flatten_output, @@ -407,6 +407,19 @@ def _forward_type_inference(self, debug=True): ) post_node.output_type = {"output": output_type} + elif isinstance(post_node, AvgPool2d): + output_shape = calculate_conv_output( + pre_node.output_type["output"][1:], + post_node.padding, + 1, + post_node.kernel_size, + post_node.stride, + ) + output_type = np.array( + [post_node.input_type["input"][0], *output_shape] + ) + post_node.output_type = {"output": output_type} + elif isinstance(post_node, Flatten): print("updateing flatten output") post_node.output_type = { diff --git a/nir/ir/pooling.py b/nir/ir/pooling.py index 29813a0..1329e8f 100644 --- a/nir/ir/pooling.py +++ b/nir/ir/pooling.py @@ -16,3 +16,16 @@ class SumPool2d(NIRNode): def __post_init__(self): self.input_type = {"input": None} self.output_type = {"output": None} + + +@dataclass(eq=False) +class AvgPool2d(NIRNode): + """Average pooling layer in 2d.""" + + kernel_size: np.ndarray # (Height, Width) + stride: np.ndarray # (Height, width) + padding: np.ndarray # (Height, width) + + def __post_init__(self): + self.input_type = {"input": None} + self.output_type = {"output": None} diff --git a/nir/serialization.py b/nir/serialization.py index 5c5d553..ca50c9b 100644 --- a/nir/serialization.py +++ b/nir/serialization.py @@ -45,6 +45,12 @@ def read_node(node: Any) -> nir.typing.NIRNode: stride=node["stride"][()], padding=node["padding"][()], ) + elif node["type"][()] == b"AvgPool2d": + return nir.AvgPool2d( + kernel_size=node["kernel_size"][()], + stride=node["stride"][()], + padding=node["padding"][()], + ) elif node["type"][()] == b"Delay": return nir.Delay(delay=node["delay"][()]) elif node["type"][()] == b"Flatten": diff --git a/tests/test_ir.py b/tests/test_ir.py index eea0223..d73813b 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -331,15 +331,44 @@ def test_inputs_outputs_properties(): def test_sumpool_type_inference(): graphs = { - 'undef graph output': nir.NIRGraph(nodes={ - 'input': nir.Input(input_type=np.array([1, 64, 64])), - 'sumpool': nir.SumPool2d( - kernel_size=np.array([2, 2]), - stride=np.array([2, 2]), - padding=np.array([0, 0]) - ), - 'output': nir.Output(output_type=None) - }, edges=[('input', 'sumpool'), ('sumpool', 'output')]), + "undef graph output": nir.NIRGraph( + nodes={ + "input": nir.Input(input_type=np.array([1, 64, 64])), + "sumpool": nir.SumPool2d( + kernel_size=np.array([2, 2]), + stride=np.array([2, 2]), + padding=np.array([0, 0]), + ), + "output": nir.Output(output_type=None), + }, + edges=[("input", "sumpool"), ("sumpool", "output")], + ), + } + for name, graph in graphs.items(): + try: + graph._check_types() + except Exception: + pass + else: + raise AssertionError(f"type check failed for: {name}") + graph.infer_types() + assert graph._check_types(), f"type inference failed for: {name}" + + +def test_avgpool_type_inference(): + graphs = { + "undef graph output": nir.NIRGraph( + nodes={ + "input": nir.Input(input_type=np.array([1, 64, 64])), + "avgpool": nir.AvgPool2d( + kernel_size=np.array([2, 2]), + stride=np.array([2, 2]), + padding=np.array([0, 0]), + ), + "output": nir.Output(output_type=None), + }, + edges=[("input", "avgpool"), ("avgpool", "output")], + ), } for name, graph in graphs.items(): try: @@ -347,52 +376,49 @@ def test_sumpool_type_inference(): except Exception: pass else: - raise AssertionError(f'type check failed for: {name}') + raise AssertionError(f"type check failed for: {name}") graph.infer_types() - assert graph._check_types(), f'type inference failed for: {name}' + assert graph._check_types(), f"type inference failed for: {name}" def test_flatten_type_inference(): graphs = { - 'undef graph output': nir.NIRGraph(nodes={ - 'input': nir.Input(input_type=np.array([1, 64, 64])), - 'flatten': nir.Flatten( - start_dim=0, - end_dim=0, - input_type=np.array([1, 64, 64]) - ), - 'output': nir.Output(output_type=None) - }, edges=[('input', 'flatten'), ('flatten', 'output')]), - - 'incorrect graph output': nir.NIRGraph(nodes={ - 'input': nir.Input(input_type=np.array([1, 64, 64])), - 'flatten': nir.Flatten( - start_dim=0, - end_dim=0, - input_type=np.array([1, 64, 64]) - ), - 'output': nir.Output(output_type=np.array([1, 61, 1])) - }, edges=[('input', 'flatten'), ('flatten', 'output')]), - - 'undef flatten.input': nir.NIRGraph(nodes={ - 'input': nir.Input(input_type=np.array([1, 64, 64])), - 'flatten': nir.Flatten( - start_dim=0, - end_dim=0, - input_type=None - ), - 'output': nir.Output(output_type=np.array([1, 61, 61])) - }, edges=[('input', 'flatten'), ('flatten', 'output')]), - - 'undef flatten.input and graph output': nir.NIRGraph(nodes={ - 'input': nir.Input(input_type=np.array([1, 64, 64])), - 'flatten': nir.Flatten( - start_dim=0, - end_dim=0, - input_type=None - ), - 'output': nir.Output(output_type=None) - }, edges=[('input', 'flatten'), ('flatten', 'output')]) + "undef graph output": nir.NIRGraph( + nodes={ + "input": nir.Input(input_type=np.array([1, 64, 64])), + "flatten": nir.Flatten( + start_dim=0, end_dim=0, input_type=np.array([1, 64, 64]) + ), + "output": nir.Output(output_type=None), + }, + edges=[("input", "flatten"), ("flatten", "output")], + ), + "incorrect graph output": nir.NIRGraph( + nodes={ + "input": nir.Input(input_type=np.array([1, 64, 64])), + "flatten": nir.Flatten( + start_dim=0, end_dim=0, input_type=np.array([1, 64, 64]) + ), + "output": nir.Output(output_type=np.array([1, 61, 1])), + }, + edges=[("input", "flatten"), ("flatten", "output")], + ), + "undef flatten.input": nir.NIRGraph( + nodes={ + "input": nir.Input(input_type=np.array([1, 64, 64])), + "flatten": nir.Flatten(start_dim=0, end_dim=0, input_type=None), + "output": nir.Output(output_type=np.array([1, 61, 61])), + }, + edges=[("input", "flatten"), ("flatten", "output")], + ), + "undef flatten.input and graph output": nir.NIRGraph( + nodes={ + "input": nir.Input(input_type=np.array([1, 64, 64])), + "flatten": nir.Flatten(start_dim=0, end_dim=0, input_type=None), + "output": nir.Output(output_type=None), + }, + edges=[("input", "flatten"), ("flatten", "output")], + ), } for name, graph in graphs.items(): try: @@ -400,110 +426,125 @@ def test_flatten_type_inference(): except Exception: pass else: - raise AssertionError(f'type check failed for: {name}') + raise AssertionError(f"type check failed for: {name}") graph.infer_types() - assert graph._check_types(), f'type inference failed for: {name}' + assert graph._check_types(), f"type inference failed for: {name}" def test_conv_type_inference(): graphs = { - 'undef graph output': nir.NIRGraph(nodes={ - 'input': nir.Input(input_type=np.array([1, 64, 64])), - 'conv': nir.Conv2d( - input_shape=(64, 64), - weight=np.zeros((1, 1, 4, 4)), - stride=1, - padding=0, - dilation=1, - groups=1, - bias=None - ), - 'output': nir.Output(output_type=None) - }, edges=[('input', 'conv'), ('conv', 'output')]), - - 'incorrect graph output': nir.NIRGraph(nodes={ - 'input': nir.Input(input_type=np.array([1, 64, 64])), - 'conv': nir.Conv2d( - input_shape=(64, 64), - weight=np.zeros((1, 1, 4, 4)), - stride=1, - padding=0, - dilation=1, - groups=1, - bias=None - ), - 'output': nir.Output(output_type=np.array([1, 61, 1])) - }, edges=[('input', 'conv'), ('conv', 'output')]), - - 'undef conv.input': nir.NIRGraph(nodes={ - 'input': nir.Input(input_type=np.array([1, 64, 64])), - 'conv': nir.Conv2d( - input_shape=None, - weight=np.zeros((1, 1, 4, 4)), - stride=1, - padding=0, - dilation=1, - groups=1, - bias=None - ), - 'output': nir.Output(output_type=np.array([1, 61, 61])) - }, edges=[('input', 'conv'), ('conv', 'output')]), - - 'undef conv.input and graph output': nir.NIRGraph(nodes={ - 'input': nir.Input(input_type=np.array([1, 64, 64])), - 'conv': nir.Conv2d( - input_shape=None, - weight=np.zeros((1, 1, 4, 4)), - stride=1, - padding=0, - dilation=1, - groups=1, - bias=None - ), - 'output': nir.Output(output_type=None) - }, edges=[('input', 'conv'), ('conv', 'output')]), - - 'Conv1d undef graph output': nir.NIRGraph(nodes={ - 'input': nir.Input(input_type=np.array([1, 64])), - 'conv': nir.Conv1d( - input_shape=64, - weight=np.zeros((1, 1, 4)), - stride=1, - padding=0, - dilation=1, - groups=1, - bias=None - ), - 'output': nir.Output(output_type=None) - }, edges=[('input', 'conv'), ('conv', 'output')]), - - 'Conv1d incorrect graph output': nir.NIRGraph(nodes={ - 'input': nir.Input(input_type=np.array([1, 64])), - 'conv': nir.Conv1d( - input_shape=64, - weight=np.zeros((1, 1, 4)), - stride=1, - padding=0, - dilation=1, - groups=1, - bias=None - ), - 'output': nir.Output(output_type=np.array([1, 3])) - }, edges=[('input', 'conv'), ('conv', 'output')]), - - 'Conv1d undef conv.input and graph output': nir.NIRGraph(nodes={ - 'input': nir.Input(input_type=np.array([1, 64])), - 'conv': nir.Conv1d( - input_shape=None, - weight=np.zeros((1, 1, 4)), - stride=1, - padding=0, - dilation=1, - groups=1, - bias=None - ), - 'output': nir.Output(output_type=None) - }, edges=[('input', 'conv'), ('conv', 'output')]), + "undef graph output": nir.NIRGraph( + nodes={ + "input": nir.Input(input_type=np.array([1, 64, 64])), + "conv": nir.Conv2d( + input_shape=(64, 64), + weight=np.zeros((1, 1, 4, 4)), + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + ), + "output": nir.Output(output_type=None), + }, + edges=[("input", "conv"), ("conv", "output")], + ), + "incorrect graph output": nir.NIRGraph( + nodes={ + "input": nir.Input(input_type=np.array([1, 64, 64])), + "conv": nir.Conv2d( + input_shape=(64, 64), + weight=np.zeros((1, 1, 4, 4)), + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + ), + "output": nir.Output(output_type=np.array([1, 61, 1])), + }, + edges=[("input", "conv"), ("conv", "output")], + ), + "undef conv.input": nir.NIRGraph( + nodes={ + "input": nir.Input(input_type=np.array([1, 64, 64])), + "conv": nir.Conv2d( + input_shape=None, + weight=np.zeros((1, 1, 4, 4)), + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + ), + "output": nir.Output(output_type=np.array([1, 61, 61])), + }, + edges=[("input", "conv"), ("conv", "output")], + ), + "undef conv.input and graph output": nir.NIRGraph( + nodes={ + "input": nir.Input(input_type=np.array([1, 64, 64])), + "conv": nir.Conv2d( + input_shape=None, + weight=np.zeros((1, 1, 4, 4)), + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + ), + "output": nir.Output(output_type=None), + }, + edges=[("input", "conv"), ("conv", "output")], + ), + "Conv1d undef graph output": nir.NIRGraph( + nodes={ + "input": nir.Input(input_type=np.array([1, 64])), + "conv": nir.Conv1d( + input_shape=64, + weight=np.zeros((1, 1, 4)), + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + ), + "output": nir.Output(output_type=None), + }, + edges=[("input", "conv"), ("conv", "output")], + ), + "Conv1d incorrect graph output": nir.NIRGraph( + nodes={ + "input": nir.Input(input_type=np.array([1, 64])), + "conv": nir.Conv1d( + input_shape=64, + weight=np.zeros((1, 1, 4)), + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + ), + "output": nir.Output(output_type=np.array([1, 3])), + }, + edges=[("input", "conv"), ("conv", "output")], + ), + "Conv1d undef conv.input and graph output": nir.NIRGraph( + nodes={ + "input": nir.Input(input_type=np.array([1, 64])), + "conv": nir.Conv1d( + input_shape=None, + weight=np.zeros((1, 1, 4)), + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + ), + "output": nir.Output(output_type=None), + }, + edges=[("input", "conv"), ("conv", "output")], + ), } for name, graph in graphs.items(): try: @@ -512,6 +553,6 @@ def test_conv_type_inference(): except Exception: pass else: - raise AssertionError(f'type check failed for: {name}') + raise AssertionError(f"type check failed for: {name}") graph.infer_types() - assert graph._check_types(), f'type inference failed for: {name}' + assert graph._check_types(), f"type inference failed for: {name}" diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index a9dcdec..4ef1c3a 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -207,3 +207,18 @@ def test_sum_pool_2d(): ] ) factory_test_graph(ir) + + +def test_avg_pool_2d(): + ir = nir.NIRGraph.from_list( + [ + nir.Input(input_type=np.array([2, 2, 10, 10])), + nir.AvgPool2d( + kernel_size=np.array([2, 2]), + stride=np.array([1, 1]), + padding=np.ndarray([0, 0]), + ), + nir.Output(output_type=np.array([2, 2, 5, 5])), + ] + ) + factory_test_graph(ir)