diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index 4c7d4aa79..5ef45d60d 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -20,6 +20,7 @@ Layers
    Linear
    Conv1d
    Conv2d
+   BatchNorm
    LayerNorm
    RMSNorm
    GroupNorm
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index d54e45f6d..5ac82356a 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -36,7 +36,7 @@
 from mlx.nn.layers.dropout import Dropout, Dropout2d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
-from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
+from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
 from mlx.nn.layers.transformer import (
diff --git a/python/mlx/nn/layers/dropout.py b/python/mlx/nn/layers/dropout.py
index 14c5cb15e..caa7a6452 100644
--- a/python/mlx/nn/layers/dropout.py
+++ b/python/mlx/nn/layers/dropout.py
@@ -5,7 +5,7 @@
 
 
 class Dropout(Module):
-    """Randomly zero a portion of the elements during training.
+    r"""Randomly zero a portion of the elements during training.
 
     The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
     :math:`p` is the probability of zeroing an element. This is done so the
@@ -36,15 +36,13 @@ def __call__(self, x):
 
 
 class Dropout2d(Module):
-    """Apply 2D channel-wise dropout during training.
+    r"""Apply 2D channel-wise dropout during training.
 
     Randomly zero out entire channels independently with probability :math:`p`.
     This layer expects the channels to be last, i.e. the input shape should be
-    ``NWHC`` or ``WHC`` where:
-    - ``N`` is the batch dimension
-    - ``H`` is the input image height
-    - ``W`` is the input image width
-    - ``C`` is the number of input channels
+    ``NWHC`` or ``WHC`` where: ``N`` is the batch dimension, ``H`` is the
+    input image height, ``W`` is the input image width, and ``C`` is the
+    number of input channels.
 
     The remaining channels are scaled by :math:`\frac{1}{1-p}` to maintain
     the expected value of each element. Unlike traditional dropout,
diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index 6de377cda..9cd578fb2 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
+from typing import Tuple
+
 import mlx.core as mx
 from mlx.nn.layers.base import Module
 
@@ -178,3 +180,121 @@ def __call__(self, x):
         )
         x = group_norm(x)
         return (self.weight * x + self.bias) if "weight" in self else x
+
+
+class BatchNorm(Module):
+    r"""Applies Batch Normalization over a 2D, 3D, or 4D input.
+
+    Computes
+
+    .. math::
+
+        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,
+
+    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
+    parameters initialized at 1 and 0 respectively.
+
+    The input shape is specified as ``NC`` or ``NLC``, where ``N`` is the
+    batch, ``C`` is the number of features or channels, and ``L`` is the
+    sequence length. The output has the same shape as the input. For
+    four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are
+    the height and width respectively.
+
+    For more information on Batch Normalization, see the original paper `Batch
+    Normalization: Accelerating Deep Network Training by Reducing Internal
+    Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
+
+    Args:
+        num_features (int): The feature dimension to normalize over.
+        eps (float, optional): A small additive constant for numerical
+            stability. Default: ``1e-5``.
+        momentum (float, optional): The momentum for updating the running
+            mean and variance. Default: ``0.1``.
+        affine (bool, optional): If ``True``, apply a learned affine
+            transformation after the normalization. Default: ``True``.
+        track_running_stats (bool, optional): If ``True``, track the
+            running mean and variance. Default: ``True``.
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn as nn
+        >>> x = mx.random.normal((5, 4))
+        >>> bn = nn.BatchNorm(num_features=4, affine=True)
+        >>> output = bn(x)
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        eps: float = 1e-5,
+        momentum: float = 0.1,
+        affine: bool = True,
+        track_running_stats: bool = True,
+    ):
+        super().__init__()
+
+        self.num_features = num_features
+        self.eps = eps
+        self.momentum = momentum
+        self.track_running_stats = track_running_stats
+
+        if affine:
+            self.weight = mx.ones((num_features,))
+            self.bias = mx.zeros((num_features,))
+
+        if self.track_running_stats:
+            self._running_mean = mx.zeros((num_features,))
+            self._running_var = mx.ones((num_features,))
+
+    def _extra_repr(self):
+        return (
+            f"{self.num_features}, eps={self.eps}, "
+            f"momentum={self.momentum}, affine={'weight' in self}, "
+            f"track_running_stats={self.track_running_stats}"
+        )
+
+    def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
+        """
+        Calculate the mean and variance of the input tensor.
+
+        Args:
+            x (mx.array): Input tensor.
+
+        Returns:
+            tuple: Tuple containing mean and variance.
+        """
+        reduction_axes = tuple(range(0, x.ndim - 1))
+        means = mx.mean(x, axis=reduction_axes, keepdims=True)
+        var = mx.var(x, axis=reduction_axes, keepdims=True)
+
+        if self.track_running_stats and self.training:
+            self._running_mean = (
+                1 - self.momentum
+            ) * self._running_mean + self.momentum * means
+            self._running_var = (
+                1 - self.momentum
+            ) * self._running_var + self.momentum * var
+        return means, var
+
+    def __call__(self, x: mx.array) -> mx.array:
+        """
+        Forward pass of BatchNorm.
+
+        Args:
+            x (mx.array): Input tensor.
+
+        Returns:
+            mx.array: Output tensor.
+        """
+
+        if x.ndim < 2 or x.ndim > 4:
+            raise ValueError(
+                f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
+            )
+
+        if self.training or not self.track_running_stats:
+            means, var = self._calc_stats(x)
+        else:
+            means, var = self._running_mean, self._running_var
+        x = (x - means) * mx.rsqrt(var + self.eps)
+        return (self.weight * x + self.bias) if "weight" in self else x
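Note (not part of the patch): a minimal usage sketch of the layer added above, assuming it is exported as mlx.nn.BatchNorm per the __init__.py change; the input sizes are arbitrary.

    # Illustrative sketch only: exercises the BatchNorm layer added in
    # normalization.py; behaviour described here follows the class above.
    import mlx.core as mx
    import mlx.nn as nn

    x = mx.random.normal((8, 4))        # ``NC`` input: batch of 8, 4 features
    bn = nn.BatchNorm(num_features=4)   # affine and running stats on by default

    bn.train()                          # normalize with batch statistics and
    y = bn(x)                           # update _running_mean / _running_var

    bn.eval()                           # reuse the tracked running statistics
    y = bn(x)                           # instead of recomputing them per batch

With track_running_stats=False, the layer always normalizes with the statistics of the current batch, in both modes.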
diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index cfb6ffa15..91316fd04 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -286,7 +286,7 @@ def _reduce(loss: mx.array, reduction: str = "none"):
 def hinge_loss(
     inputs: mx.array, targets: mx.array, reduction: str = "none"
 ) -> mx.array:
-    """
+    r"""
     Computes the hinge loss between inputs and targets.
 
     .. math::
@@ -311,7 +311,7 @@
 def huber_loss(
     inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
 ) -> mx.array:
-    """
+    r"""
     Computes the Huber loss between inputs and targets.
 
     .. math::
@@ -345,7 +345,7 @@
 def log_cosh_loss(
     inputs: mx.array, targets: mx.array, reduction: str = "none"
 ) -> mx.array:
-    """
+    r"""
     Computes the log cosh loss between inputs and targets.
 
     Logcosh acts like L2 loss for small errors, ensuring stable gradients,
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index cc56bc430..2cfac4475 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -320,6 +320,143 @@ def test_group_norm(self):
         self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
         self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
 
+    def test_batch_norm(self):
+        mx.random.seed(42)
+        x = mx.random.normal((5, 4), dtype=mx.float32)
+
+        # Batch norm
+        bn = nn.BatchNorm(num_features=4, affine=True)
+        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
+        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.439520, 1.647328, -0.955515, 1.966031],
+                [-1.726690, -1.449826, -0.234026, -0.723364],
+                [0.938414, -0.349603, -0.354470, -0.175369],
+                [0.305006, 0.234914, -0.393017, -0.459385],
+                [0.922789, -0.082813, 1.937028, -0.607913],
+            ],
+        )
+        expected_mean = mx.array([0.008929, 0.005680, -0.016092, 0.027778])
+        expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
+
+        # test eval mode
+        bn.eval()
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.15984, 1.73159, -1.25456, 1.57891],
+                [-0.872193, -1.4281, -0.414439, -0.228678],
+                [0.602743, -0.30566, -0.554687, 0.139639],
+                [0.252199, 0.29066, -0.599572, -0.0512532],
+                [0.594096, -0.0334829, 2.11359, -0.151081],
+            ]
+        )
+
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+
+        # test_no_affine
+        bn = nn.BatchNorm(num_features=4, affine=False)
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.439520, 1.647328, -0.955515, 1.966031],
+                [-1.726690, -1.449826, -0.234026, -0.723364],
+                [0.938414, -0.349603, -0.354470, -0.175369],
+                [0.305006, 0.234914, -0.393017, -0.459385],
+                [0.922789, -0.082813, 1.937028, -0.607913],
+            ]
+        )
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+
+        # test with 3D input
+        mx.random.seed(42)
+        N = 2
+        L = 4
+        C = 5
+        x = mx.random.normal((N, L, C), dtype=mx.float32)
+
+        # Batch norm
+        bn = nn.BatchNorm(num_features=C, affine=True)
+        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
+        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
+        y = bn(x)
+        self.assertTrue(x.shape == y.shape)
+        expected_y = mx.array(
+            [
+                [
+                    [-0.335754, 0.342054, 1.02653, 0.628588, -1.63899],
+                    [1.92092, 0.432319, 0.343043, 1.95489, 1.0696],
+                    [-0.853748, 1.3661, 0.868569, 0.0199196, -0.887284],
+                    [0.459206, -0.684822, -0.706354, -0.271531, 0.566341],
+                ],
+                [
+                    [-0.921179, 0.684951, -0.77466, -0.490372, -0.247032],
+                    [1.10839, -2.13179, 0.628924, -1.62639, -0.539708],
+                    [-0.348943, 0.412194, -2.03818, 0.524972, 1.64568],
+                    [-1.02889, -0.421, 0.652127, -0.740079, 0.0313996],
+                ],
+            ]
+        )
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+        expected_mean = mx.array(
+            [[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
+        )
+        expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
+        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
+
+        x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
+        with self.assertRaises(ValueError):
+            y = bn(x)
+
+    def test_batch_norm_stats(self):
+        batch_size = 2
+        num_features = 4
+        h = 3
+        w = 3
+        momentum = 0.1
+
+        batch_norm = nn.BatchNorm(num_features)
+
+        batch_norm.train()
+        running_mean = np.array(batch_norm._running_mean)
+        running_var = np.array(batch_norm._running_var)
+
+        data = mx.random.normal((batch_size, num_features))
+
+        normalized_data = batch_norm(data)
+        np_data = np.array(data)
+        means = np.mean(np_data, axis=0)
+        variances = np.var(np_data, axis=0)
+        running_mean = (1 - momentum) * running_mean + momentum * means
+        running_var = (1 - momentum) * running_var + momentum * variances
+        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
+
+        batch_norm = nn.BatchNorm(num_features)
+
+        batch_norm.train()
+        running_mean = np.array(batch_norm._running_mean)
+        running_var = np.array(batch_norm._running_var)
+        data = mx.random.normal((batch_size, h, w, num_features))
+
+        normalized_data = batch_norm(data)
+        np_data = np.array(data)
+        means = np.mean(np_data, axis=(0, 1, 2))
+        variances = np.var(np_data, axis=(0, 1, 2))
+        running_mean = (1 - momentum) * running_mean + momentum * means
+        running_var = (1 - momentum) * running_var + momentum * variances
+        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
+
     def test_conv1d(self):
         N = 5
         L = 12
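Note (not part of the patch): a sketch of how the new layer might be composed with existing mlx.nn modules. nn.Module, nn.Linear, train() and eval() are existing MLX API; the MLP class, the layer sizes, and the mx.maximum-based ReLU are illustrative assumptions, not anything introduced by this diff.

    # Illustrative sketch only, assuming nn.BatchNorm is exported as in the
    # __init__.py change above.
    import mlx.core as mx
    import mlx.nn as nn


    class MLP(nn.Module):
        def __init__(self, in_dims: int, hidden_dims: int, out_dims: int):
            super().__init__()
            self.fc1 = nn.Linear(in_dims, hidden_dims)
            self.bn = nn.BatchNorm(hidden_dims)   # normalizes the feature axis
            self.fc2 = nn.Linear(hidden_dims, out_dims)

        def __call__(self, x: mx.array) -> mx.array:
            x = self.bn(self.fc1(x))
            x = mx.maximum(x, 0.0)                # plain ReLU to keep the sketch minimal
            return self.fc2(x)


    model = MLP(16, 32, 10)
    x = mx.random.normal((4, 16))

    model.train()   # BatchNorm uses batch statistics and updates running stats
    y = model(x)

    model.eval()    # BatchNorm switches to the tracked running statistics
    y = model(x)

Since train() and eval() propagate to submodules, toggling the mode on the model is what switches the embedded BatchNorm between batch statistics and the tracked running statistics.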