I'm puzzled by the number of trainable parameters in networks using gconv2d.
The script below creates a network using gconv2ds from Z2 to C4 to C4 and counts the number of learnable parameters in the network.
For the groups C4 and D4, the result is S times larger than what I expected, where S is the number of non-translation transformations, i.e. roto-flips (so S=4 for C4 and S=8 for D4).
Specifically, I'd expect a gconv2d layer to have the same number of learnable parameters as a normal 2D conv layer (namely n_feat_maps_in*n_feat_maps_out*kernel_size**2, when both have no biases, as is the case for this repository).
So, for C4, I would expect the number of learnable parameters in the example below to be 135 + 315, but it instead turns out to be 135 + 315*4. Similarly for D4, we get 135 + 315*8.
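For concreteness, the counts quoted above follow from simple arithmetic (the variable names here are just for illustration):

```python
# Parameter counts for the example network below (illustrative arithmetic only).
kernel_size = 3
n0, n1, n2 = 3, 5, 7   # feature maps in each layer
s = 4                  # number of non-translation transformations in C4

layer1 = n0 * n1 * kernel_size**2             # Z2 -> C4 layer: 135
layer2_expected = n1 * n2 * kernel_size**2    # what I'd expect for C4 -> C4: 315
layer2_observed = layer2_expected * s         # what the script reports: 1260

print(layer1, layer2_expected, layer2_observed)            # 135 315 1260
print(layer1 + layer2_expected, layer1 + layer2_observed)  # 450 1395
```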
I understand how the total number of parameters should be 135 + 315*4 for C4 and 135 + 315*8 for D4, since the filters are practically speaking different (in that they have been roto-flipped).
However, I don't think that they should all be individually learnable (since the roto-flip transformations are not learnable), and I'm worried that there may be a problem in the implementation.
It could also very well be that I have misunderstood something fundamental, but isn't the whole point of gconvs related to a group G that they are equivariant to the transformations in G without an increase in the number of trainable parameters?
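To illustrate what I mean by "different but not individually learnable": in plain NumPy (this is just an illustration, not GrouPy's actual implementation), a single canonical filter can be expanded into its four C4 rotations, so the filter bank stores four times as many entries while only the nine canonical values are free parameters:

```python
import numpy as np

rng = np.random.default_rng(0)
canonical = rng.standard_normal((3, 3))  # 9 learnable parameters

# Expand into the four rotated copies; these all share the canonical parameters.
bank = np.stack([np.rot90(canonical, k) for k in range(4)])
print(bank.shape)  # (4, 3, 3): 36 stored entries, still only 9 free parameters
```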
Finally, the test for equivariance at the end of the below script also fails. Is this related, or am I testing the wrong thing?
For the record, I see the same behavior with the keras_gcnn Keras implementation: model.summary() reports the same (higher than expected) number of trainable parameters.
Thank you for your time, and for this awesome work!
import numpy as np
import tensorflow as tf
from groupy.gconv.tensorflow_gconv.splitgconv2d import gconv2d, gconv2d_util
# Model parameters
kernel_size = 3
n_feat_maps_0 = 3
n_feat_maps_1 = 5
n_feat_maps_2 = 7
group_0 = 'Z2'
group_1 = 'C4'
group_2 = 'C4' # Not currently implemented for C4 --> D4
# Construct graph
x = tf.placeholder(tf.float32, [None, 9, 9, n_feat_maps_0])
# Z2 --> C4 convolution
gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input=group_0, h_output=group_1, in_channels=n_feat_maps_0, out_channels=n_feat_maps_1, ksize=kernel_size)
w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
y = gconv2d(input=x, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
# C4 --> C4 convolution
gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input=group_1, h_output=group_2, in_channels=n_feat_maps_1, out_channels=n_feat_maps_2, ksize=kernel_size)
w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
y = gconv2d(input=y, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
# Compute
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
output = sess.run(y, feed_dict={x: np.random.randn(10, 9, 9, 3)})
print(output.shape) # (10, 9, 9, 28)
# Count the number of trainable parameters
print(np.sum([np.prod(v.shape) for v in tf.trainable_variables()])) # 1395 (135 + 315*4)
# Test equivariance by comparing outputs for rotated versions of same datapoint
datapoint = np.random.randn(9, 9, 3)
input = np.stack([datapoint, np.rot90(datapoint)])
output = sess.run(y, feed_dict={x: input})
print(np.allclose(output[0], np.rot90(output[1]))) # False
sess.close()
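In case it's relevant to the failing check: my guess is that a C4 feature map should transform under a quarter-turn of the input by a spatial rotation combined with a cyclic shift along the rotation-channel axis, so comparing output[0] to a plain np.rot90 of output[1] may be testing the wrong thing. The channel layout (out_channels x 4 rotation channels) below is an assumption on my part; the snippet only sanity-checks that four such quarter-turns compose to the identity:

```python
import numpy as np

def rotate_c4_feature_map(y):
    # Assumed layout: channels ordered as (out_channels, 4 rotation channels).
    h, w, c = y.shape
    y = y.reshape(h, w, c // 4, 4)
    y = np.rot90(y, axes=(0, 1))   # rotate spatially
    y = np.roll(y, 1, axis=3)      # cyclically shift the rotation channels
    return y.reshape(h, w, c)

y = np.random.randn(9, 9, 28)
y4 = y
for _ in range(4):
    y4 = rotate_c4_feature_map(y4)
print(np.allclose(y4, y))  # True: four quarter-turns compose to the identity
```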