
Implementation of Group-Convolutions as Keras layers.


Group Convolutions in Keras 3 with GroCo


GroCo implements group equivariant convolutions in Keras 3. It stays as close as possible to the interface of the standard convolution layers and supports the most common layer types: 2D and 3D convolutions, transposed convolutions, and pooling layers.

It is written in a fully backend-agnostic way, and trivially supports any backend that Keras supports (TensorFlow, PyTorch and JAX).

Installation

Using poetry (recommended)

  1. Clone the repo: `git clone https://github.com/APJansen/GroCo.git`
  2. Navigate to the project folder: `cd groco`
  3. If necessary, install poetry: `pip install poetry`
  4. Run `poetry install`

Your working environment can be activated via `poetry shell`, or, to use it for a single command, prefix that command with `poetry run`. More information on poetry is available here.

Introduction to group convolutions


This Colab notebook is a standalone introduction to group convolutions, meant to be read alongside the lectures, book, or paper mentioned above. It does not use the implementation in GroCo, but rather derives an early version of it from scratch, going into many of the nitty-gritty details. I hope it can be useful in parallel with the other sources.

Example notebook


This Colab notebook illustrates how to use GroCo by constructing a group convolutional network, training it on MNIST, and comparing it to a regular network. It also illustrates how to pool onto subgroups, which improves performance on MNIST relative to the full group convolution (though not relative to the regular convolution: MNIST is used only as a simple example and does not lend itself well to group convolutions, since orientation matters in its images).

Overview

Convolutions are equivariant to translations, meaning if we have a convolutional layer L, an input x and a translation T, then if we translate the input and then apply the layer, we obtain the same result as if we apply the layer first, and then perform the translation: L(T(x)) == T(L(x)). Group convolutions generalize this and are equivariant to a larger group.
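To make the equivariance property concrete, here is a minimal sketch (independent of GroCo, with made-up function names) that checks `L(T(x)) == T(L(x))` for a one-dimensional circular convolution, where `T` is a cyclic translation:

```python
import numpy as np

def conv(x, k):
    # circular cross-correlation: out[i] = sum_j k[j] * x[(i + j) % n]
    n = len(x)
    return np.array([sum(k[j] * x[(i + j) % n] for j in range(len(k)))
                     for i in range(n)])

def translate(x, t):
    # cyclic translation by t grid points
    return np.roll(x, t)

x = np.arange(8.0)
k = np.array([1.0, -2.0, 0.5])

# translating then convolving gives the same result as
# convolving then translating
lhs = conv(translate(x, 3), k)
rhs = translate(conv(x, k), 3)
assert np.allclose(lhs, rhs)
```

The circular (periodic) boundary is what makes the equivariance exact here; with zero padding it only holds up to edge effects.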

All of the generalizations follow the same three-step procedure:

  1. Interpret the existing functionality in terms of a group. Most importantly, we interpret a grid (of pixels say) not as points but rather as translations. The mapping is trivial: the grid point with indices (i, j) is a translation that maps any point (x, y) to (x + i, y + j).
  2. Expand the group with what's called a point group, from just translations to translations combined with other symmetries such as rotations and mirroring along an axis. The original translations together with the new point group are, as a whole, called a wallpaper group.
  3. Apply the standard functionality directly, in the same manner, to this larger group.
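Step 1 can be sketched in a few lines; the helper name below is made up for illustration:

```python
# Step 1 sketch: interpret a grid index (i, j) as the translation
# that maps any point (x, y) to (x + i, y + j).
def as_translation(i, j):
    return lambda x, y: (x + i, y + j)

t = as_translation(2, 3)
assert t(1, 1) == (3, 4)

# Composing two such translations is the same as adding their
# grid indices, which is what makes the grid a group.
composed = as_translation(2, 3)(*as_translation(1, -1)(0, 0))
assert composed == (3, 2)
```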

Here we give a quick overview of the implemented functionality. G refers to the point group, and |G| to its order, or the number of elements. The largest wallpaper group on a square lattice, p4m, and all its subgroups have been implemented.
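As an illustration of how a point group acts, the four rotations of p4 (a subgroup of p4m) can be applied to a 2D kernel, producing the |G| = 4 transformed copies that the convolution entry below refers to. This sketch uses plain NumPy rather than GroCo's own machinery:

```python
import numpy as np

# Sketch: the p4 point group (rotations by 0, 90, 180, 270 degrees)
# acting on a 3x3 kernel; |G| = 4 transformed copies are produced.
kernel = np.arange(9.0).reshape(3, 3)
stack = np.stack([np.rot90(kernel, r) for r in range(4)])

assert stack.shape == (4, 3, 3)
# rotating four times returns the original kernel (group closure)
assert np.array_equal(np.rot90(kernel, 4), kernel)
```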

- **Convolution**
  - *Group interpretation:* apply group transformations to a kernel, multiply with the signal and sum
  - *Generalization:* before the regular convolution, transform the kernel with the point group
  - *Implementation:* `GroupConv2D(group='point_group_name', ...)`
  - *Resulting differences:* the kernel gets copied |G| times; the number of parameters stays the same, but the number of output channels grows by a factor |G|
  - *Reference:* CohenWelling2016
- **Stride**
  - *Group interpretation:* subsample on a subgroup of the translations, i.e. the translations with even coordinates for stride 2
  - *Generalization:* subsampling onto any subgroup of the wallpaper group
  - *Implementation:* `GroupConv2D(..., stride=s, subgroup='point_subgroup_name', ...)`
  - *Resulting differences:* strides are done as usual*, and independently we can subsample on a subgroup of the point group
  - *Reference:* CohenWelling2016
  - *Comments:* strides are tricky in that they can cause the origin of the new, smaller grid to not coincide with the original origin, which breaks equivariance. To prevent this, the default padding option is `valid_equiv`, which pads a minimal extra amount. This can be turned off by setting padding back to `valid`; `same_equiv` is also possible.
- **Pooling**
  - *Group interpretation:* again subsample on a subgroup of the translations, but first aggregate on its cosets closest to the identity
  - *Generalization:* subsampling onto any subgroup of the wallpaper group
  - *Implementation:* `GroupMaxPooling2D(group='group_name', subgroup='subgroup_name', ...)`, and the same with `GroupAveragePooling2D`
  - *Resulting differences:* in addition to pooling over the grid, potentially subsample on a subgroup of the point group, after aggregating on its cosets
  - *Reference:* CohenWelling2016
- **3D Convolution**
  - *Group interpretation:* apply group transformations to a kernel, multiply with the signal and sum
  - *Generalization:* before the regular convolution, transform the kernel with the point group
  - *Implementation:* `GroupConv3D(group='point_group_name', ...)`
  - *Resulting differences:* the kernel gets copied |G| times; the number of parameters stays the same, but the number of output channels grows by a factor |G|
  - *Reference:* WinkelsCohen2018
  - *Comments:* this is conceptually identical to 2D convolutions, but the groups get a lot bigger.
- **Transposed Convolution**
  - *Group interpretation:* a transposed convolution can be seen as an upsampling to a larger group, of which the starting group is a subgroup (I don't think there's a group the