implement-batch-norm-layer #217
Conversation
Hey @awni, I've been thinking about how we should structure the batch normalization module. Do you think it's a good idea to have one class that covers all batch normalization types (like `BatchNorm1d` and `BatchNorm2d`), or separate classes for each?
From an implementation standpoint, I think having a single class makes sense. @gboduljak has an implementation in #216.
I fully agree with the idea of a unified BatchNorm. It's a clean and versatile approach for both implementation and the API.
@awni Here is an updated version of BN that is general:

```python
from typing import Tuple

import mlx.core as mx
from mlx.nn.layers.base import Module


class BatchNorm(Module):
    def __init__(
        self,
        num_features: int,
        num_dims: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
    ):
        super().__init__()
        dims_dict = {
            2: ((1, num_features), (0,)),
            3: ((1, num_features, 1), (0, 2)),
            4: ((1, num_features, 1, 1), (0, 2, 3)),
        }
        if num_dims not in dims_dict:
            raise ValueError(f"expected num_dims to be 2, 3, or 4 (got {num_dims})")
        shape, self.reduction_axes = dims_dict[num_dims]
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = mx.ones(shape)
            self.bias = mx.zeros(shape)
        if self.track_running_stats:
            self.running_mean = mx.zeros(shape)
            self.running_var = mx.ones(shape)

    def _extra_repr(self):
        return (
            f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, "
            f"affine={'weight' in self}, track_running_stats={self.track_running_stats}"
        )

    def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
        """
        Calculate the mean and variance of the input tensor.

        Args:
            x (mx.array): Input tensor.

        Returns:
            tuple: Tuple containing mean and variance.
        """
        means = mx.mean(x, axis=self.reduction_axes, keepdims=True)
        var = mx.var(x, axis=self.reduction_axes, keepdims=True)
        if self.track_running_stats and self.training:
            self.running_mean = (
                1 - self.momentum
            ) * self.running_mean + self.momentum * means
            self.running_var = (
                1 - self.momentum
            ) * self.running_var + self.momentum * var
        return means, var

    def __call__(self, x: mx.array):
        """
        Forward pass of BatchNorm.

        Args:
            x (mx.array): Input tensor.

        Returns:
            mx.array: Output tensor.
        """
        if self.training or not self.track_running_stats:
            means, var = self._calc_stats(x)
        else:
            means, var = self.running_mean, self.running_var
        x = (x - means) * mx.rsqrt(var + self.eps)
        return (self.weight * x + self.bias) if "weight" in self else x
```

It can be used as follows:

```python
batch_size = 4
num_features = 32
num_iters = 5

input = mx.random.normal((batch_size, num_features))
bn = BatchNorm(num_features=num_features, num_dims=2)
output = bn(input)
```
We can remove the num_dims parameter by updating the implementation like so:

```python
class BatchNorm(Module):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = mx.ones((num_features,))
            self.bias = mx.zeros((num_features,))
        if self.track_running_stats:
            self.running_mean = mx.zeros((num_features,))
            self.running_var = mx.ones((num_features,))

    def _extra_repr(self):
        return (
            f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, "
            f"affine={'weight' in self}, track_running_stats={self.track_running_stats}"
        )

    def _check_and_expand_dims(self, x: mx.array):
        """
        Check that the input is a 2D, 3D, or 4D tensor and expand the weight,
        bias, running mean, and running variance accordingly.

        Args:
            x (mx.array): Input tensor.
        """
        num_dims = len(x.shape)
        dims_dict = {
            2: ((1, self.num_features), (0,)),
            3: ((1, self.num_features, 1), (0, 2)),
            4: ((1, self.num_features, 1, 1), (0, 2, 3)),
        }
        if num_dims not in dims_dict:
            raise ValueError(f"expected num_dims to be 2, 3, or 4 (got {num_dims})")
        shape, self.reduction_axes = dims_dict[num_dims]
        if self.affine and self.weight.ndim != num_dims:
            self.weight = mx.expand_dims(self.weight, self.reduction_axes)
            self.bias = mx.expand_dims(self.bias, self.reduction_axes)
        if self.track_running_stats and self.running_mean.ndim != num_dims:
            self.running_mean = mx.expand_dims(self.running_mean, self.reduction_axes)
            self.running_var = mx.expand_dims(self.running_var, self.reduction_axes)

    def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
        """
        Calculate the mean and variance of the input tensor.

        Args:
            x (mx.array): Input tensor.

        Returns:
            tuple: Tuple containing mean and variance.
        """
        means = mx.mean(x, axis=self.reduction_axes, keepdims=True)
        var = mx.var(x, axis=self.reduction_axes, keepdims=True)
        if self.track_running_stats and self.training:
            self.running_mean = (
                1 - self.momentum
            ) * self.running_mean + self.momentum * means
            self.running_var = (
                1 - self.momentum
            ) * self.running_var + self.momentum * var
        return means, var

    def __call__(self, x: mx.array):
        """
        Forward pass of BatchNorm.

        Args:
            x (mx.array): Input tensor.

        Returns:
            mx.array: Output tensor.
        """
        self._check_and_expand_dims(x)
        if self.training or not self.track_running_stats:
            means, var = self._calc_stats(x)
        else:
            means, var = self.running_mean, self.running_var
        x = (x - means) * mx.rsqrt(var + self.eps)
        return (self.weight * x + self.bias) if "weight" in self else x
```
@m0saan your final suggestion looks great. We can then implement some validations, e.g. an axis cannot be both a reduction axis and a feature axis.
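For illustration only, a check along those lines might look like the following (the `feature_axis`/`reduction_axes` parameters are hypothetical and not part of this PR):

```python
# Hypothetical validation for a BatchNorm that takes explicit axes:
# the feature (channel) axis must not also be reduced over.
def validate_axes(feature_axis: int, reduction_axes: tuple):
    if feature_axis in reduction_axes:
        raise ValueError(
            f"axis {feature_axis} cannot be both a feature axis and a reduction axis"
        )

validate_axes(1, (0, 2, 3))    # OK for an (N, C, H, W) layout
# validate_axes(1, (0, 1, 2))  # would raise ValueError
```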
In this discussion post, I included some ideas to test batch norm layers. I think it is important to verify we match PyTorch and/or Jax implementations. Maybe you can use these tests. I can also add them. Your tests look good as well, but we may want to test whether we are doing moving stats tracking correctly. It would also be beneficial to test that BatchNorm is behaving correctly in train/eval mode.
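As a rough sketch of the train/eval behaviour being suggested here (it assumes the num_dims-free `BatchNorm` class defined earlier in this thread and that `Module.train()`/`Module.eval()` toggle `self.training`):

```python
import mlx.core as mx

bn = BatchNorm(num_features=4)
x = mx.random.normal((8, 4))

bn.train()
_ = bn(x)                         # training mode: batch stats are used, running stats updated
mean_after_train = bn.running_mean

bn.eval()
_ = bn(x)                         # eval mode: stored running stats are used
assert bn.running_mean is mean_after_train  # running stats must not change in eval mode
```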
Hey @gboduljak, thanks a lot for your input! I really value your ideas on making the BatchNorm class more flexible. We would like to maintain the simplicity and user-friendliness of the framework. While your proposed changes provide additional options, they may also introduce complexity that might be unnecessary for many use cases. Maybe @awni has some thoughts on this too, including your point on the performance impact of using `expand_dims`?
I will incorporate these ideas into the testing of BN. If you have additional tests to add, feel free to include them, and we can collaborate to ensure comprehensive testing.
Could someone explain to me what the axes are in

```python
dims_dict = {
    2: ((1, self.num_features), (0,)),
    3: ((1, self.num_features, 1), (0, 2)),
    4: ((1, self.num_features, 1, 1), (0, 2, 3)),
}
```

Besides, if we generalise to N-dimensional BatchNorm, would dims_dict become something like

```python
dims_dict = {
    N: ((1, self.num_features, *(1 for _ in range(2, N))), tuple(i for i in range(N) if i != 1)),
}
```
Depending on the BatchNorm you want to implement (e.g. 1D, 2D), you want to normalize over different axes.
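For example, with the channels-second layout used in the snippets above, an (N, C, L) input is normalized per channel by reducing over the batch and length axes (a small illustration, not code from this PR):

```python
import mlx.core as mx

x = mx.random.normal((8, 16, 32))               # (N, C, L): batch, channels, length
mean = mx.mean(x, axis=(0, 2), keepdims=True)   # per-channel mean, shape (1, 16, 1)
var = mx.var(x, axis=(0, 2), keepdims=True)     # per-channel variance, shape (1, 16, 1)
x_hat = (x - mean) * mx.rsqrt(var + 1e-5)       # normalized over batch and length, per channel
```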
I just came up with a new idea to avoid repeatedly calling `_check_and_expand_dims` on every forward pass. @m0saan what do you think?
Hello @gboduljak, I apologize for the delayed response.
@awni can you please review?
Sorry for the delay in reviewing this, we were busy getting v0.0.6 out yesterday w/ quantization etc. I will get on this asap.
Based on trying this layer in MIMM, using mimm/scripts/train.py to train on ImageNet.
This is really nice! I have one request that I think will make it perfect, then we can merge it.
The input tensor shape is specified as (N, C) or (N, C, L), representing the batch size (N), the number of features or channels (C), and optionally, the sequence length (L). The output tensor maintains the same shape as the input, adhering to (N, C) or (N, C, L).
For four-dimensional tensors, the shape is denoted as (N, C, H, W), where N signifies the batch size, C represents the number of channels, H corresponds to the height, and W denotes the width.
This looks great!
I think one change can make it more consistent and a lot simpler:
For our convolutions (and in general) we follow the convention that the channels are last, so inputs to convolutions are NLC or NHWC. We should change two things:
- Batch norm should also follow that convention
- Since it is following that convention it should easily broadcast with the inputs and you can remove the whole check_and_expand_dims machinery and just let broadcasting manage it (it's super cheap to expand dims at runtime so from a perf perspective it should be trivial!)
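To illustrate the broadcasting point (a sketch with assumed shapes, not the merged implementation): with channels-last inputs, parameters of shape (C,) broadcast against the input directly:

```python
import mlx.core as mx

x = mx.random.normal((8, 32, 32, 16))    # NHWC input: channels last
weight = mx.ones((16,))                  # per-channel scale, shape (C,)
bias = mx.zeros((16,))                   # per-channel shift, shape (C,)

axes = tuple(range(x.ndim - 1))          # reduce over every axis except channels: (0, 1, 2)
mean = mx.mean(x, axis=axes, keepdims=True)
var = mx.var(x, axis=axes, keepdims=True)

# (C,) broadcasts against (N, H, W, C) with no explicit expand_dims needed.
y = weight * (x - mean) * mx.rsqrt(var + 1e-5) + bias
```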
Done, I've updated the batch norm implementation and tests to handle inputs of shape NLC and NHWC!
Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
Update BatchNorm to support NLC and NHWC input formats

In our convolution operations, we follow the convention that the channels are the last dimension. This commit updates the BatchNorm implementation to support inputs where the channels are the last dimension (NLC or NHWC). This involves changing the dimensions we normalize over and the dimensions we expand our parameters over.

Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>

(force-pushed from 3a235af to a1c06b7)
🚀 This looks awesome, thanks for adding it!
torch.var uses a bias correction by default whereas MLX and NumPy do not; that is why you see the slight difference. I think it is the right call for now to use the uncorrected variance in our BN, as I believe PyTorch also uses the uncorrected variance in its normalization layers.
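To make the difference concrete (a small sketch; it assumes `mx.var` accepts a NumPy-style `ddof` argument):

```python
import mlx.core as mx

x = mx.random.normal((10,))
n = x.size

biased = mx.var(x)              # default in MLX/NumPy: divide by n (ddof=0)
unbiased = mx.var(x, ddof=1)    # default in torch.var: divide by n - 1
# The two differ by the Bessel correction factor n / (n - 1):
print(biased * n / (n - 1), unbiased)
```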
PS @gboduljak, @dc-dc-dc, @robertmccraith thanks for the extra reviews / discussion! PS @robertmccraith I'm following mimm eagerly, keep us posted on how it's going and what else you need to get it fully operational!
Thanks @awni for your inputs!
- Add batch normalization layer --------- Co-authored-by: Robert McCraith <mccraithrobert@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
Proposed changes
Description
This pull request introduces an implementation of Batch Normalization, following the specifications outlined in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
Changes Made
- Implemented a `BatchNorm1d` class that extends the `Module` class.
- Added configuration parameters: `eps` (numerical stability constant), `momentum` (for running mean and variance updates), and `affine` (whether to include learnable affine parameters).
- Tested the `BatchNorm1d` module with and without learnable parameters.

Usage
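The usage snippet from the original description is not preserved in this capture; below is a minimal sketch based on the parameters listed above, assuming the layer is exposed as `mlx.nn.BatchNorm` as settled on later in this thread:

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((4, 16))     # (N, C) input: batch of 4, 16 features
bn = nn.BatchNorm(num_features=16, eps=1e-5, momentum=0.1, affine=True)
y = bn(x)                         # output has the same shape as the input
```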
Notes
Please review and provide feedback.
Checklist
Put an `x` in the boxes that apply.
- I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes