Tensorflow 2 port for GrouPy #24

Open: wants to merge 3 commits into base: master
24 changes: 16 additions & 8 deletions README.md
@@ -1,3 +1,8 @@
# TF2-GrouPy

A Tensorflow 2 port for GrouPy. Best used alongside [tf2-keras-gcnn](https://github.com/neel-dey/tf2-keras-gcnn). Note: this repo includes [group-to-Z2 indexing from nom](https://github.com/nom/GrouPy/commit/7b1128b6cb0d33e5733667f8e07490ea1d44a7dc).

Original README follows below with a tf2-compatible minimal-working-example.

### Note: If you are looking for a PyTorch implementation, please have a look at the pull requests by Jorn Peters and Adam Bielski (https://github.com/tscohen/GrouPy/pulls).

@@ -41,34 +46,37 @@ $ nosetests -v

### TensorFlow

```python
import numpy as np
import tensorflow as tf

# The example builds a TF1-style graph, so eager execution is turned off first.
tf.compat.v1.disable_eager_execution()

from groupy.gconv.tensorflow_gconv.splitgconv2d import gconv2d, gconv2d_util

# Construct graph
x = tf.compat.v1.placeholder(tf.float32, [None, 9, 9, 3])  # was tf.placeholder

# First layer: Z2 -> D4
gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input='Z2', h_output='D4', in_channels=3, out_channels=64, ksize=3)
w = tf.Variable(tf.random.truncated_normal(w_shape, stddev=1.))  # was tf.truncated_normal
y = gconv2d(input=x, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)

# Second layer: D4 -> D4
gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input='D4', h_output='D4', in_channels=64, out_channels=64, ksize=3)
w = tf.Variable(tf.random.truncated_normal(w_shape, stddev=1.))  # was tf.truncated_normal
y = gconv2d(input=y, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)

# Compute
init = tf.compat.v1.global_variables_initializer()  # was tf.global_variables_initializer
sess = tf.compat.v1.Session()  # was tf.Session
sess.run(init)
y = sess.run(y, feed_dict={x: np.random.randn(10, 9, 9, 3)})
sess.close()

print(y.shape)  # (10, 9, 9, 512); was the Python 2 statement `print y.shape`
```
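
The `(10, 9, 9, 512)` shape printed by the example follows from the channel layout: each of the 64 output channels on D4 is stored as 8 adjacent 2D channels, one per roto-reflection. A minimal sketch of that arithmetic is given here; `GROUP_ORDER` and `num_2d_channels` are illustrative names and not part of GrouPy.

```python
# Number of transformations in each group accepted by h_input / h_output.
GROUP_ORDER = {"Z2": 1, "C4": 4, "D4": 8}


def num_2d_channels(group, channels):
    """Return how many 2D channels `channels` feature maps on `group` occupy."""
    return channels * GROUP_ORDER[group]


# 64 channels on D4 occupy 64 * 8 = 512 2D channels.
print(num_2d_channels("D4", 64))  # 512
```

The same rule explains why the input placeholder has only 3 channels: on Z2 there is a single transformation per channel.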

### Chainer
@@ -157,4 +165,4 @@ These subclasses can easily be tested against the group axioms and other mathema

## References

1. <a name="gcnn"></a> T.S. Cohen, M. Welling, [Group Equivariant Convolutional Networks](http://www.jmlr.org/proceedings/papers/v48/cohenc16.pdf). Proceedings of the International Conference on Machine Learning (ICML), 2016.
22 changes: 16 additions & 6 deletions groupy/gconv/tensorflow_gconv/splitgconv2d.py
@@ -12,7 +12,6 @@ def gconv2d(input, filter, strides, padding, gconv_indices, gconv_shape_info,
Tensorflow implementation of the group convolution.
This function has the same interface as the standard convolution nn.conv2d, except for two new parameters,
gconv_indices and gconv_shape_info. These can be obtained from gconv2d_util(), and are described below

:param input: a tensor with (batch, height, width, in channels) axes.
:param filter: a tensor with (ksize, ksize, in channels * in transformations, out channels) axes.
The shape for filter can be obtained from gconv2d_util().
@@ -37,7 +36,7 @@ def gconv2d(input, filter, strides, padding, gconv_indices, gconv_shape_info,
     transformed_filter = transform_filter_2d_nhwc(w=filter, flat_indices=gconv_indices, shape_info=gconv_shape_info)

     # Convolve input with transformed filters
-    conv = tf.nn.conv2d(input=input, filter=transformed_filter, strides=strides, padding=padding,
+    conv = tf.compat.v1.nn.conv2d(input=input, filter=transformed_filter, strides=strides, padding=padding,
                         use_cudnn_on_gpu=use_cudnn_on_gpu, data_format=data_format, name=name)

     return conv
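
For readers new to the filter transformation step, the idea in the simplest case (a first layer going from Z2 to C4) is that every learned kernel is applied in four orientations. The sketch below shows this on one single-channel kernel using plain lists. It is only an illustration under that assumption: the library itself gathers from precomputed index arrays (`gconv_indices`) rather than rotating explicitly, and `rot90` and `c4_orbit` are hypothetical helper names.

```python
def rot90(kernel):
    """Rotate a square 2D kernel (a list of rows) by 90 degrees counter-clockwise."""
    n = len(kernel)
    return [[kernel[c][n - 1 - r] for c in range(n)] for r in range(n)]


def c4_orbit(kernel):
    """Return the kernel rotated by 0, 90, 180 and 270 degrees."""
    copies = [kernel]
    for _ in range(3):
        copies.append(rot90(copies[-1]))
    return copies


kernel = [[1, 2, 3],
          [4, 5, 6],
          [7, 8, 9]]

# One learned kernel yields four filters, so the convolution output has
# four 2D channels per output channel.
for rotated in c4_orbit(kernel):
    print(rotated)
```

A fourth rotation brings the kernel back to its starting orientation, which is why C4 contributes exactly four copies.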
@@ -50,9 +49,8 @@ def gconv2d_util(h_input, h_output, in_channels, out_channels, ksize):
     1) an array of indices used in the filter transformation step of gconv2d
     2) shape information required by gconv2d
     3) the shape of the filter tensor to be allocated and passed to gconv2d
-
     :param h_input: one of ('Z2', 'C4', 'D4'). Use 'Z2' for the first layer. Use 'C4' or 'D4' for later layers.
-    :param h_output: one of ('C4', 'D4'). What kind of transformations to use (rotations or roto-reflections).
+    :param h_output: one of ('Z2', 'C4', 'D4'). What kind of transformations to use (rotations or roto-reflections).
     The choice of h_output of one layer should equal h_input of the next layer.
     :param in_channels: the number of input channels. Note: this refers to the number of (3D) channels on the group.
     The number of 2D channels will be 1, 4, or 8 times larger, depending on the value of h_input.
@@ -78,10 +76,22 @@ def gconv2d_util(h_input, h_output, in_channels, out_channels, ksize):
         gconv_indices = flatten_indices(make_d4_p4m_indices(ksize=ksize))
         nti = 8
         nto = 8
+    elif h_input == 'D4' and h_output == 'Z2':
+        gconv_indices = flatten_indices(make_d4_z2_indices(ksize=ksize))
+        nti = 8
+        nto = 1
+    elif h_input == 'C4' and h_output == 'Z2':
+        gconv_indices = flatten_indices(make_c4_z2_indices(ksize=ksize))
+        nti = 4
+        nto = 1
     else:
         raise ValueError('Unknown (h_input, h_output) pair:' + str((h_input, h_output)))

-    w_shape = (ksize, ksize, in_channels * nti, out_channels)
+    if h_output == 'Z2':
+        w_shape = (ksize, ksize, in_channels, out_channels)
+    else:
+        w_shape = (ksize, ksize, in_channels * nti, out_channels)
+
     gconv_shape_info = (out_channels, nto, in_channels, nti, ksize)
     return gconv_indices, gconv_shape_info, w_shape
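
To make the shape bookkeeping in this hunk easier to check, here is a table-driven sketch of the same logic without the index arrays. It is not the PR's code: `TRANSFORM_COUNTS` and `filter_shapes` are hypothetical names, and the `(nti, nto)` values for the pairs not visible in this diff are assumed from the number of transformations in each group (1 for Z2, 4 for C4, 8 for D4).

```python
# (h_input, h_output) -> (nti, nto). Pairs not shown in the diff are assumed
# from the group sizes: Z2 has 1 transformation, C4 has 4, D4 has 8.
TRANSFORM_COUNTS = {
    ("Z2", "C4"): (1, 4),
    ("C4", "C4"): (4, 4),
    ("Z2", "D4"): (1, 8),
    ("D4", "D4"): (8, 8),
    ("D4", "Z2"): (8, 1),  # pair added in this PR
    ("C4", "Z2"): (4, 1),  # pair added in this PR
}


def filter_shapes(h_input, h_output, in_channels, out_channels, ksize):
    """Return (gconv_shape_info, w_shape) following the branches in the diff."""
    key = (h_input, h_output)
    if key not in TRANSFORM_COUNTS:
        raise ValueError("Unknown (h_input, h_output) pair: " + str(key))
    nti, nto = TRANSFORM_COUNTS[key]
    if h_output == "Z2":
        # Branch added in this PR: no nti factor when projecting back to Z2.
        w_shape = (ksize, ksize, in_channels, out_channels)
    else:
        w_shape = (ksize, ksize, in_channels * nti, out_channels)
    gconv_shape_info = (out_channels, nto, in_channels, nti, ksize)
    return gconv_shape_info, w_shape


print(filter_shapes("Z2", "D4", 3, 64, 3))
print(filter_shapes("D4", "Z2", 64, 10, 3))
```

Keeping the pairs in one table makes it easy to see that the two new entries only differ from the existing ones in having `nto = 1`.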

@@ -94,7 +104,6 @@ def gconv2d_addbias(input, bias, nti=8):
A G-feature map usually consists of a number (e.g. 4 or 8) of adjacent channels.
This function will add a single bias vector to a stack of feature maps that has e.g. 4 or 8 times more 2D channels
than G-channels, by replicating the bias across adjacent groups of 2D channels.

:param input: tensor of shape (n, h, w, ni * nti), where n is the batch dimension, (h, w) are the height and width,
ni is the number of input G-channels, and nti is the number of transformations in H.
:param bias: tensor of shape (ni,)
@@ -103,3 +112,4 @@ def gconv2d_addbias(input, bias, nti=8):
"""
# input = tf.reshape(input, ())
pass # TODO
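
Since `gconv2d_addbias` is still a TODO, the replication described in its docstring can be illustrated on a single pixel with plain lists. This is a sketch only: it assumes the `nti` 2D channels of each G-channel are stored next to each other, as the docstring says, and `expand_bias` and `add_bias_to_pixel` are hypothetical names. In NumPy the same expansion is `np.repeat(bias, nti)`.

```python
def expand_bias(bias, nti=8):
    """Repeat each bias value nti times, keeping the copies adjacent."""
    return [b for b in bias for _ in range(nti)]


def add_bias_to_pixel(channels, bias, nti=8):
    """Add one bias per G-channel to a pixel's vector of ni * nti 2D channels."""
    expanded = expand_bias(bias, nti)
    if len(channels) != len(expanded):
        raise ValueError(
            "expected %d channels, got %d" % (len(expanded), len(channels)))
    return [x + b for x, b in zip(channels, expanded)]


# Two G-channels with four transformations each: eight 2D channels in total.
print(add_bias_to_pixel([0.0] * 8, [1.0, 2.0], nti=4))
# [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0]
```

Sharing one bias across the transformed copies of a channel is what keeps the bias term from breaking equivariance.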