diff --git a/tvnet.py b/tvnet.py index 00fb855..cbeefb4 100644 --- a/tvnet.py +++ b/tvnet.py @@ -125,7 +125,34 @@ def forward_gradient(self, x, name): diff_y = tf.concat(axis=1, values=[diff_y_valid, last_row]) return diff_x, diff_y + + def forward_gradient_forloss(self, x, name): + assert len(x.shape) == 4 + + with tf.variable_scope('forward_gradient'): + x_ker_init = tf.constant_initializer([[-1, 1]]) + diff_x = tf.layers.conv2d(x, x.shape[-1].value, [1, 2], padding='same', + kernel_initializer=x_ker_init, use_bias=False, name=name + '_diff_x', + trainable=True) + + y_ker_init = tf.constant_initializer([[-1], [1]]) + diff_y = tf.layers.conv2d(x, x.shape[-1].value, [2, 1], padding='same', + kernel_initializer=y_ker_init, use_bias=False, name=name + '_diff_y', + trainable=True) + # refine the boundary + diff_x_valid = tf.slice(diff_x, begin=[0, 0, 0, 0], + size=[-1, x.shape[1].value, x.shape[2].value - 1, x.shape[3].value]) + last_col = tf.zeros([tf.shape(x)[0], x.shape[1].value, 1, x.shape[3].value], dtype=tf.float32) + diff_x = tf.concat(axis=2, values=[diff_x_valid, last_col]) + + diff_y_valid = tf.slice(diff_y, begin=[0, 0, 0, 0], + size=[-1, x.shape[1].value - 1, x.shape[2].value, x.shape[3].value]) + last_row = tf.zeros([tf.shape(x)[0], 1, x.shape[2].value, x.shape[3].value], dtype=tf.float32) + diff_y = tf.concat(axis=1, values=[diff_y_valid, last_row]) + + return diff_x, diff_y + def divergence(self, x, y, name): assert len(x.shape) == 4 @@ -310,8 +337,8 @@ def get_loss(self, x1, x2, max_iterations=max_iterations) # computing loss - u1x, u1y = self.forward_gradient(u1, 'u1') - u2x, u2y = self.forward_gradient(u2, 'u2') + u1x, u1y = self.forward_gradient_forloss(u1, 'u1') + u2x, u2y = self.forward_gradient_forloss(u2, 'u2') u1_flat = tf.reshape(u1, (tf.shape(x2)[0], 1, x2.shape[1].value * x2.shape[2].value))