Hi Dr. Cohen - thanks so much for providing the GrouPy and gconv_experiments repos.
I was wondering about the correct way to apply a coset max-pool to the output y in your TensorFlow example:
# Construct graph
x = tf.placeholder(tf.float32, [None, 9, 9, 3])
gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input='Z2', h_output='D4', in_channels=3, out_channels=64, ksize=3)
w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
y = gconv2d(input=x, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input='D4', h_output='D4', in_channels=64, out_channels=64, ksize=3)
w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
y = gconv2d(input=y, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
...
print(y.shape)  # (10, 9, 9, 512)
My understanding is that the last dimension, 512, is the 64 output channels of the last layer multiplied by 8, the number of output transformations in D4 (4 rotations, each with and without a flip). I assume I'd want to apply a coset max-pool to this output, so that its shape is (10, 9, 9, 64) before feeding it to the next layer. Is this assumption correct?
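To make the arithmetic I have in mind explicit, here is a tiny sketch. The `GROUP_ORDER` table and `flat_channels` helper are my own hypothetical names, not part of GrouPy, and the group orders are just what I understand them to be (1 for Z2, 4 for C4, 8 for D4).

```python
# Hypothetical lookup of group orders; this is NOT part of GrouPy.
GROUP_ORDER = {'Z2': 1, 'C4': 4, 'D4': 8}


def flat_channels(out_channels, h_output):
    """Size of the flattened channel axis I expect from a gconv layer."""
    return out_channels * GROUP_ORDER[h_output]


print(flat_channels(64, 'D4'))  # 512
```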
I'm not very familiar with Chainer and I'm having a bit of trouble translating the Chainer code in gconv_experiments to TensorFlow. I'd appreciate any guidance on recreating the following lines from your paper:
"Next, we replaced each convolution by a p4-convolution (eq. 10 and 11...and added max-pooling over rotations after the last convolution layer."
and
"We took the Z2CNN, replaced each convolution layer by a p4-convolution (eq. 10) followed by a coset max-pooling over rotations."
Is there a distinction between max-pooling over rotations and coset max-pooling over rotations?
My best guess would be to do something like the following - would this be correct?
y = gconv2d(input=y, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)  # (10, 9, 9, 512)
y_reshaped = tf.reshape(y, [-1, 9, 9, 64, 8])  # break the flat 512 into 64 x 8
y_maxpooled = tf.reduce_max(y_reshaped, axis=4)  # take max along the last dimension
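For what it's worth, here is a small NumPy sketch of what I expect this reshape-and-reduce to do. It assumes that the 512 channels are laid out with the 8 transformations varying fastest within each of the 64 output channels; I have not verified this against GrouPy. `coset_max_pool` is my own hypothetical helper, not a GrouPy function.

```python
import numpy as np


def coset_max_pool(y, out_channels=64, group_order=8):
    """Max over the group axis of a flat (N, H, W, out_channels * group_order) array."""
    n, h, w, c = y.shape
    assert c == out_channels * group_order, "unexpected channel count"
    # ASSUMPTION: the transformation index varies fastest within each output channel.
    y_grouped = y.reshape(n, h, w, out_channels, group_order)
    return y_grouped.max(axis=-1)


rng = np.random.RandomState(0)
y = rng.randn(10, 9, 9, 512).astype(np.float32)
pooled = coset_max_pool(y)
print(pooled.shape)  # (10, 9, 9, 64)
```

If the layout turned out to be the other way around (transformation index slowest), I assume the reshape would need to be (N, H, W, 8, 64) with the max taken over axis 3 instead.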
Thank you so much!