implement-batch-norm-layer (ml-explore#217)
- Add batch normalization layer

---------

Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
3 people committed Dec 25, 2023
1 parent 9e6b8c9 commit a123c3c
Showing 6 changed files with 267 additions and 11 deletions.
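Before the per-file diffs, a quick usage sketch of the new layer (shapes are illustrative; the call pattern follows the docstring example added below):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.random.normal((8, 16))       # (N, C): a batch of 8 feature vectors
    bn = nn.BatchNorm(num_features=16)  # affine and running stats on by default
    y = bn(x)                           # same shape out, normalized per feature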
1 change: 1 addition & 0 deletions docs/src/python/nn/layers.rst
@@ -20,6 +20,7 @@ Layers
Linear
Conv1d
Conv2d
BatchNorm
LayerNorm
RMSNorm
GroupNorm
2 changes: 1 addition & 1 deletion python/mlx/nn/layers/__init__.py
@@ -36,7 +36,7 @@
from mlx.nn.layers.dropout import Dropout, Dropout2d
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.transformer import (
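With the updated import line, the class resolves through both paths (a trivial smoke check, assuming a build containing this commit):

    import mlx.nn as nn
    from mlx.nn.layers.normalization import BatchNorm

    assert nn.BatchNorm is BatchNorm  # same class, re-exported at the package level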
12 changes: 5 additions & 7 deletions python/mlx/nn/layers/dropout.py
@@ -5,7 +5,7 @@


class Dropout(Module):
"""Randomly zero a portion of the elements during training.
r"""Randomly zero a portion of the elements during training.
The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
:math:`p` is the probability of zeroing an element. This is done so the
@@ -36,15 +36,13 @@ def __call__(self, x):


class Dropout2d(Module):
"""Apply 2D channel-wise dropout during training.
r"""Apply 2D channel-wise dropout during training.
Randomly zero out entire channels independently with probability :math:`p`.
This layer expects the channels to be last, i.e. the input shape should be
``NWHC`` or ``WHC`` where:
- ``N`` is the batch dimension
- ``H`` is the input image height
- ``W`` is the input image width
- ``C`` is the number of input channels
``NWHC`` or ``WHC`` where: ``N`` is the batch dimension, ``H`` is the input
image height, ``W`` is the input image width, and ``C`` is the number of
input channels.
The remaining channels are scaled by :math:`\frac{1}{1-p}` to
maintain the expected value of each element. Unlike traditional dropout,
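Since this file only reflows the Dropout2d docstring, here is a minimal sketch of the layer it documents (shapes illustrative; modules start in training mode, as the tests below rely on):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.random.normal((2, 8, 8, 3))  # NWHC input: batch, width, height, channels
    drop = nn.Dropout2d(p=0.5)          # zeroes entire channels with probability p
    y = drop(x)                         # surviving channels are scaled by 1/(1-p)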
120 changes: 120 additions & 0 deletions python/mlx/nn/layers/normalization.py
@@ -1,5 +1,7 @@
# Copyright © 2023 Apple Inc.

from typing import Tuple

import mlx.core as mx
from mlx.nn.layers.base import Module

@@ -178,3 +180,121 @@ def __call__(self, x):
)
x = group_norm(x)
return (self.weight * x + self.bias) if "weight" in self else x


class BatchNorm(Module):
r"""Applies Batch Normalization over a 2D or 3D input.
Computes
.. math::
y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,
where :math:`\gamma` and :math:`\beta` are learned per feature dimension
parameters initialized at 1 and 0 respectively.
The input shape is specified as ``NC`` or ``NLC``, where ``N`` is the
batch, ``C`` is the number of features or channels, and ``L`` is the
sequence length. The output has the same shape as the input. For
four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are
the height and width, respectively.
For more information on Batch Normalization, see the original paper `Batch
Normalization: Accelerating Deep Network Training by Reducing Internal
Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
Args:
num_features (int): The feature dimension to normalize over.
eps (float, optional): A small additive constant for numerical
stability. Default: ``1e-5``.
momentum (float, optional): The momentum for updating the running
mean and variance. Default: ``0.1``.
affine (bool, optional): If ``True``, apply a learned affine
transformation after the normalization. Default: ``True``.
track_running_stats (bool, optional): If ``True``, track the
running mean and variance. Default: ``True``.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> x = mx.random.normal((5, 4))
>>> bn = nn.BatchNorm(num_features=4, affine=True)
>>> output = bn(x)
"""

def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: float = 0.1,
affine: bool = True,
track_running_stats: bool = True,
):
super().__init__()

self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.track_running_stats = track_running_stats

if affine:
self.weight = mx.ones((num_features,))
self.bias = mx.zeros((num_features,))

if self.track_running_stats:
self._running_mean = mx.zeros((num_features,))
self._running_var = mx.ones((num_features,))

def _extra_repr(self):
return (
f"{self.num_features}, eps={self.eps}, "
f"momentum={self.momentum}, affine={'weight' in self}, "
f"track_running_stats={self.track_running_stats}"
)

def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
"""
Calculate the mean and variance of the input over all axes except the
last (feature) axis. When training with ``track_running_stats`` enabled,
also update the running statistics as a side effect.
Args:
x (mx.array): Input tensor.
Returns:
tuple: Tuple containing mean and variance.
"""
reduction_axes = tuple(range(0, x.ndim - 1))
means = mx.mean(x, axis=reduction_axes, keepdims=True)
var = mx.var(x, axis=reduction_axes, keepdims=True)

if self.track_running_stats and self.training:
self._running_mean = (
1 - self.momentum
) * self._running_mean + self.momentum * means
self._running_var = (
1 - self.momentum
) * self._running_var + self.momentum * var
return means, var

def __call__(self, x: mx.array) -> mx.array:
"""
Forward pass of BatchNorm.
Args:
x (mx.array): Input tensor.
Returns:
mx.array: Output tensor.
"""

if x.ndim < 2 or x.ndim > 4:
raise ValueError(
f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
)

if self.training or not self.track_running_stats:
means, var = self._calc_stats(x)
else:
means, var = self._running_mean, self._running_var
x = (x - means) * mx.rsqrt(var + self.eps)
return (self.weight * x + self.bias) if "weight" in self else x
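The branch above is the heart of the layer: batch statistics (plus a momentum update of the running averages) in training, stored statistics in eval. A minimal round trip, using only calls shown in this diff:

    import mlx.core as mx
    import mlx.nn as nn

    bn = nn.BatchNorm(num_features=4)
    x = mx.random.normal((5, 4))

    y_train = bn(x)  # training (the default): normalize with batch stats and
                     # update _running_mean/_running_var via the momentum rule
    bn.eval()
    y_eval = bn(x)   # inference: normalize with the tracked running statistics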
6 changes: 3 additions & 3 deletions python/mlx/nn/losses.py
@@ -286,7 +286,7 @@ def _reduce(loss: mx.array, reduction: str = "none"):
def hinge_loss(
inputs: mx.array, targets: mx.array, reduction: str = "none"
) -> mx.array:
"""
r"""
Computes the hinge loss between inputs and targets.
.. math::
@@ -311,7 +311,7 @@ def huber_loss(
def huber_loss(
inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
) -> mx.array:
"""
r"""
Computes the Huber loss between inputs and targets.
.. math::
@@ -345,7 +345,7 @@ def huber_loss(
def log_cosh_loss(
inputs: mx.array, targets: mx.array, reduction: str = "none"
) -> mx.array:
"""
r"""
Computes the log cosh loss between inputs and targets.
Logcosh acts like L2 loss for small errors, ensuring stable gradients,
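The only change here is the r prefix on three docstrings, which stops Python from consuming the backslashes that Sphinx :math: markup needs. A quick illustration:

    plain = "\frac{1}{1-p}"      # "\f" is parsed as a form-feed escape
    raw = r"\frac{1}{1-p}"       # the raw string keeps the backslash intact
    print(len(plain), len(raw))  # 12 13 -- one character was lost to the escape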
137 changes: 137 additions & 0 deletions python/tests/test_nn.py
@@ -320,6 +320,143 @@ def test_group_norm(self):
self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))

def test_batch_norm(self):
mx.random.seed(42)
x = mx.random.normal((5, 4), dtype=mx.float32)

# Batch norm
bn = nn.BatchNorm(num_features=4, affine=True)
self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
y = bn(x)
expected_y = mx.array(
[
[-0.439520, 1.647328, -0.955515, 1.966031],
[-1.726690, -1.449826, -0.234026, -0.723364],
[0.938414, -0.349603, -0.354470, -0.175369],
[0.305006, 0.234914, -0.393017, -0.459385],
[0.922789, -0.082813, 1.937028, -0.607913],
],
)
expected_mean = mx.array([0.008929, 0.005680, -0.016092, 0.027778])
expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
self.assertTrue(x.shape == y.shape)
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))

# test eval mode
bn.eval()
y = bn(x)
expected_y = mx.array(
[
[-0.15984, 1.73159, -1.25456, 1.57891],
[-0.872193, -1.4281, -0.414439, -0.228678],
[0.602743, -0.30566, -0.554687, 0.139639],
[0.252199, 0.29066, -0.599572, -0.0512532],
[0.594096, -0.0334829, 2.11359, -0.151081],
]
)

self.assertTrue(x.shape == y.shape)
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))

# test_no_affine
bn = nn.BatchNorm(num_features=4, affine=False)
y = bn(x)
expected_y = mx.array(
[
[-0.439520, 1.647328, -0.955515, 1.966031],
[-1.726690, -1.449826, -0.234026, -0.723364],
[0.938414, -0.349603, -0.354470, -0.175369],
[0.305006, 0.234914, -0.393017, -0.459385],
[0.922789, -0.082813, 1.937028, -0.607913],
]
)
self.assertTrue(x.shape == y.shape)
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))

# test with 3D input
mx.random.seed(42)
N = 2
L = 4
C = 5
x = mx.random.normal((N, L, C), dtype=mx.float32)

# Batch norm
bn = nn.BatchNorm(num_features=C, affine=True)
self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
y = bn(x)
self.assertTrue(x.shape == y.shape)
expected_y = mx.array(
[
[
[-0.335754, 0.342054, 1.02653, 0.628588, -1.63899],
[1.92092, 0.432319, 0.343043, 1.95489, 1.0696],
[-0.853748, 1.3661, 0.868569, 0.0199196, -0.887284],
[0.459206, -0.684822, -0.706354, -0.271531, 0.566341],
],
[
[-0.921179, 0.684951, -0.77466, -0.490372, -0.247032],
[1.10839, -2.13179, 0.628924, -1.62639, -0.539708],
[-0.348943, 0.412194, -2.03818, 0.524972, 1.64568],
[-1.02889, -0.421, 0.652127, -0.740079, 0.0313996],
],
]
)
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
expected_mean = mx.array(
[[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
)
expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))

x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
with self.assertRaises(ValueError):
y = bn(x)

def test_batch_norm_stats(self):
batch_size = 2
num_features = 4
h = 3
w = 3
momentum = 0.1

batch_norm = nn.BatchNorm(num_features)

batch_norm.train()
running_mean = np.array(batch_norm._running_mean)
running_var = np.array(batch_norm._running_var)

data = mx.random.normal((batch_size, num_features))

normalized_data = batch_norm(data)
np_data = np.array(data)
means = np.mean(np_data, axis=0)
variances = np.var(np_data, axis=0)
running_mean = (1 - momentum) * running_mean + momentum * means
running_var = (1 - momentum) * running_var + momentum * variances
self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))

batch_norm = nn.BatchNorm(num_features)

batch_norm.train()
running_mean = np.array(batch_norm._running_mean)
running_var = np.array(batch_norm._running_var)
data = mx.random.normal((batch_size, h, w, num_features))

normalized_data = batch_norm(data)
np_data = np.array(data)
means = np.mean(np_data, axis=(0, 1, 2))
variances = np.var(np_data, axis=(0, 1, 2))
running_mean = (1 - momentum) * running_mean + momentum * means
running_var = (1 - momentum) * running_var + momentum * variances
self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))

def test_conv1d(self):
N = 5
L = 12
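The stats tests mirror the layer's exponential-moving-average update. The same rule as a one-step worked example (numbers invented for illustration):

    momentum = 0.1
    running_mean = 0.0  # freshly initialized running mean
    batch_mean = 0.5    # mean of the current batch
    running_mean = (1 - momentum) * running_mean + momentum * batch_mean
    print(running_mean)  # 0.05 after a single update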
