GroCo implements group equivariant convolutions in Keras 3. It stays as close as possible to the interface of the standard convolution layers, and supports all the most common convolutional layers. This includes 2D and 3D convolutions, transposed convolutions, and pooling layers.
It is written in a fully backend-agnostic way, and trivially supports any backend that Keras supports (TensorFlow, PyTorch and JAX).
- Clone repo:
git clone https://github.com/APJansen/GroCo.git
- Navigate to the project folder:
cd GroCo
- If necessary install poetry:
pip install poetry
- Run
poetry install
Your working environment can be activated via `poetry shell`, or to use it in a single command, prefix it with `poetry run`.
More information is available in the poetry documentation.
This Colab notebook is a standalone introduction to group convolutions, meant to be read after or in parallel with the lectures, book or paper mentioned above. It does not use the implementation in GroCo, but rather derives an early version of it from scratch, going into many of the nitty-gritty aspects. I hope it is a useful complement to the other sources.
This Colab notebook illustrates how to use GroCo by constructing a group convolutional network, training it on MNIST and comparing it to a regular network. It also illustrates how to pool onto subgroups, which increases performance on MNIST. (Though not compared to the regular convolution: MNIST is used only as a simple example and does not lend itself well to group convolutions, because orientation matters in MNIST images.)
Convolutions are equivariant to translations, meaning if we have a convolutional layer `L`, an input `x` and a translation `T`, then if we translate the input and then apply the layer, we obtain the same result as if we apply the layer first, and then perform the translation: `L(T(x)) == T(L(x))`.
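The identity above can be checked numerically. The following is a minimal sketch in plain Python, not GroCo code: `correlate` and `translate` are hypothetical helpers, and circular (wrap-around) boundaries are assumed so that the identity holds exactly rather than only away from the edges.

```python
def correlate(signal, kernel):
    """Circular 1D cross-correlation: out[i] = sum_k kernel[k] * signal[(i + k) % n]."""
    n = len(signal)
    return [sum(w * signal[(i + k) % n] for k, w in enumerate(kernel)) for i in range(n)]


def translate(signal, shift):
    """Cyclically shift the signal to the right by `shift` positions."""
    n = len(signal)
    return [signal[(i - shift) % n] for i in range(n)]


x = [1, 4, 2, 8, 5, 7]
kernel = [1, -2, 3]

lhs = correlate(translate(x, 2), kernel)  # L(T(x)): translate, then apply the layer
rhs = translate(correlate(x, kernel), 2)  # T(L(x)): apply the layer, then translate
print(lhs == rhs)
```

With zero padding instead of wrap-around, the two sides would agree only for outputs that do not touch the boundary.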
Group convolutions generalize this and are equivariant to a larger group.
All of the generalizations follow the same three-step procedure:
- Interpret the existing functionality in terms of a group. Most importantly, we interpret a grid (of pixels say) not as points but rather as translations. The mapping is trivial: the grid point with indices `(i, j)` is a translation that maps any point `(x, y)` to `(x + i, y + j)`.
- Expand the group with what's called a point group, from just translations to translations combined with other symmetries such as rotations and mirroring along an axis. The original translations together with the new point group as a whole is called a wallpaper group.
- Apply the standard functionality directly, in the same manner, to this larger group.
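The first two steps can be made concrete with a small sketch. It assumes the point group consists of the four rotations by multiples of 90 degrees (so the wallpaper group is p4) and represents a group element as a hypothetical triple `(r, i, j)`: rotate by `r * 90` degrees, then translate by `(i, j)`. None of these names come from GroCo.

```python
def rotate(point, r):
    """Rotate an integer point counter-clockwise by r * 90 degrees about the origin."""
    x, y = point
    for _ in range(r % 4):
        x, y = -y, x
    return (x, y)


def act(g, point):
    """Apply g = (r, i, j): rotate by r * 90 degrees, then translate by (i, j)."""
    r, i, j = g
    x, y = rotate(point, r)
    return (x + i, y + j)


def compose(g, h):
    """Return the group element equivalent to applying h first and then g."""
    rg, ig, jg = g
    rh, ih, jh = h
    # g(h(p)) = R_g(R_h p + t_h) + t_g = R_g R_h p + (R_g t_h + t_g)
    ti, tj = rotate((ih, jh), rg)
    return ((rg + rh) % 4, ti + ig, tj + jg)


p = (2, 1)
t = (0, 3, -1)    # step 1: the grid point (3, -1) seen as a pure translation
print(act(t, p))  # (5, 0)

g = (1, 0, 0)     # step 2: an element of the point group, a 90-degree rotation
print(act(compose(g, t), p) == act(g, act(t, p)))
```

Elements with `r = 0` reproduce the ordinary grid of translations, which is why standard convolutions appear as the special case of a trivial point group.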
Here we give a quick overview of the implemented functionality. `G` refers to the point group, and `|G|` to its order, or the number of elements.
The largest wallpaper group on a square lattice, p4m, and all its subgroups have been implemented.
| | Convolution |
|---|---|
| Group Interpretation | Apply group transformations to a kernel, multiply with signal and sum |
| Generalization | Before the regular convolution transform the kernel with the point group |
| Implementation | `GroupConv2D(group='point_group_name', ...)` |
| Resulting differences | kernel gets copied \|G\| times, number of parameters stays the same but output channel grows by factor \|G\| |
| reference | CohenWelling2016 |
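The "resulting differences" row can be illustrated with a small sketch. It assumes the point group is the group of four 90-degree rotations and only shows how one kernel yields |G| transformed copies; `rot90` and `expand_kernel` are hypothetical helpers, not GroCo's implementation, which may handle further details such as kernels that already live on the group.

```python
def rot90(kernel):
    """Rotate a square 2D kernel (list of rows) counter-clockwise by 90 degrees."""
    return [list(row) for row in zip(*kernel)][::-1]


def expand_kernel(kernel):
    """Return the |G| = 4 rotated copies of one kernel (rotations by 0, 90, 180, 270)."""
    copies = [kernel]
    for _ in range(3):
        copies.append(rot90(copies[-1]))
    return copies


kernel = [[1, 2, 3],
          [4, 5, 6],
          [7, 8, 9]]

copies = expand_kernel(kernel)
n_parameters = sum(len(row) for row in kernel)  # still the 9 values of the original kernel
n_output_channels = len(copies)                 # one output channel became |G| = 4

print(n_parameters, n_output_channels)
```

Each copy is then used in an ordinary convolution, so the extra output channels come at no cost in trainable parameters.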
| | stride |
|---|---|
| Group Interpretation | subsample on a subgroup of the translations, i.e. the translations with even coordinates for stride 2 |
| Generalization | subsampling onto any subgroup of the wallpaper group |
| Implementation | `GroupConv2D(..., stride=s, subgroup='point_subgroup_name', ...)` |
| Resulting differences | strides are done as usual*, and independently we can subsample on a subgroup of the point group |
| reference | CohenWelling2016 |
| comments | Strides are tricky in that they can cause the origin of the new, smaller grid to not coincide with the original origin and this breaks equivariance. To prevent this the default padding option is `valid_equiv`, which pads a minimal extra amount to prevent this. Can be turned off by setting it back to `valid`, and `same_equiv` is also possible. |
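The origin problem described in the comments can be reproduced on a plain grid. This sketch only demonstrates the issue that `valid_equiv` is said to address; it does not reproduce GroCo's padding logic, and all helper names are hypothetical. With stride 2, subsampling commutes with a 90-degree rotation on a 5x5 grid but not on a 4x4 grid, because on the 4x4 grid the sampled points are not mapped onto each other by the rotation.

```python
def rot90(grid):
    """Rotate a square grid (list of rows) counter-clockwise by 90 degrees."""
    return [list(row) for row in zip(*grid)][::-1]


def subsample(grid, stride=2):
    """Keep every `stride`-th row and column, starting from index 0."""
    return [row[::stride] for row in grid[::stride]]


def make_grid(n):
    """An n x n grid with distinct entries, so that any mismatch is visible."""
    return [[n * i + j for j in range(n)] for i in range(n)]


def commutes_with_rotation(n, stride=2):
    """Check whether subsampling and rotating give the same result in either order."""
    grid = make_grid(n)
    return subsample(rot90(grid), stride) == rot90(subsample(grid, stride))


print(commutes_with_rotation(4))  # False: sampled indices 0, 2 are not symmetric in 0..3
print(commutes_with_rotation(5))  # True: sampled indices 0, 2, 4 are symmetric in 0..4
```

In this toy setting, equivariance survives when the grid size minus one is divisible by the stride, which is the kind of condition a small amount of extra padding can restore.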
| | pooling |
|---|---|
| Group Interpretation | again subsample on subgroup of strides, but first aggregate on its cosets closest to the identity |
| Generalization | subsampling onto any subgroup of the wallpaper group |
| Implementation | `GroupMaxPooling2D(group='group_name', subgroup='subgroup_name', ...)`, and the same with `GroupAveragePooling2D` |
| Resulting differences | in addition to pooling over the grid, potentially subsample on a subgroup of the point group, after aggregating on its cosets |
| reference | CohenWelling2016 |
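As a rough illustration of pooling onto a subgroup of the point group, the sketch below assumes a signal on the four rotations by multiples of 90 degrees and pools it onto the subgroup of rotations by 0 and 180 degrees. The choice of coset representatives and the function name are assumptions made for this example; GroCo's layers may organise the computation differently.

```python
def pool_onto_subgroup(features, pool=max):
    """Pool a signal on the four 90-degree rotations onto the subgroup {0, 180 degrees}.

    features[k] is the response for the rotation by k * 90 degrees.
    """
    subgroup = [0, 2]               # rotations by 0 and 180 degrees
    coset_representatives = [0, 1]  # one element per coset, closest to the identity
    return [
        pool(features[(h + u) % 4] for u in coset_representatives)
        for h in subgroup
    ]


features = [1, 7, 3, 2]
pooled = pool_onto_subgroup(features)
print(pooled)  # [7, 3]

# Rotating the input by 180 degrees shifts the group axis by two positions;
# the pooled output is permuted accordingly, so equivariance to the subgroup is kept.
shifted = features[-2:] + features[:-2]
print(pool_onto_subgroup(shifted))  # [3, 7]
```

This mirrors ordinary max pooling with stride 2, where a window around each kept position is aggregated before the grid is subsampled.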
| | 3D Convolution |
|---|---|
| Group Interpretation | Apply group transformations to a kernel, multiply with signal and sum |
| Generalization | Before the regular convolution transform the kernel with the point group |
| Implementation | `GroupConv3D(group='point_group_name', ...)` |
| Resulting differences | kernel gets copied \|G\| times, number of parameters stays the same but output channel grows by factor \|G\| |
| reference | WinkelsCohen2018 |
| comments | This is conceptually identical to 2D convolutions, but the groups get a lot bigger. |
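To give a sense of how much bigger the groups get, the sketch below counts the symmetries of a cubic lattice that fix the origin, by enumerating signed permutation matrices. This is a counting exercise only: which 3D groups GroCo actually provides is not stated here, and the helper names are hypothetical. For comparison, the square lattice has 4 rotations and 8 symmetries once mirroring is included.

```python
from itertools import permutations, product


def determinant(m):
    """Determinant of a 3x3 matrix given as a tuple of rows."""
    (a, b, c), (d, e, f), (g, h, i) = m
    return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)


def signed_permutation_matrices():
    """All 3x3 matrices that map the three lattice axes onto each other, up to sign."""
    matrices = []
    for perm in permutations(range(3)):
        for signs in product((1, -1), repeat=3):
            matrices.append(tuple(
                tuple(signs[row] if col == perm[row] else 0 for col in range(3))
                for row in range(3)
            ))
    return matrices


symmetries = signed_permutation_matrices()
rotations = [m for m in symmetries if determinant(m) == 1]  # proper rotations only

print(len(rotations), len(symmetries))  # 24 48
```

A kernel transformed by the full rotation group of the cube is therefore copied 24 times rather than 4, which has a direct impact on memory and compute.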
| | Transposed Convolution |
|---|---|
| Group Interpretation | A transpose convolution can be seen as an upsampling to a larger group, of which the starting group is a subgroup (I don't think there's a group the