Skip to content

Commit d219bd2

Browse files
committed
Revise code based on GCA's review
1 parent 3ecedcf commit d219bd2

1 file changed

Lines changed: 4 additions & 5 deletions

File tree

tensorflow_quantum/core/ops/tfq_utility_ops_test.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,11 +167,10 @@ def test_padded_to_ragged(self, padded_array):
167167
def test_padded_to_ragged2d(self, padded_array):
168168
"""Test for padded_to_ragged2d utility."""
169169
tensor_arr = tf.constant(padded_array, dtype=tf.float32)
170-
col_mask = tf.abs(tensor_arr[:, 0]) < 1.1
171-
masked = tf.ragged.boolean_mask(tensor_arr, col_mask)
172-
mask = tf.abs(masked) < 1.1
173-
expected = tf.ragged.boolean_mask(masked, mask)
174-
170+
row_mask = tf.abs(tensor_arr[:, :, 0]) < 1.1
171+
masked_rows = tf.ragged.boolean_mask(tensor_arr, row_mask)
172+
element_mask = tf.abs(masked_rows) < 1.1
173+
expected = tf.ragged.boolean_mask(masked_rows, element_mask)
175174
actual = tfq_utility_ops.padded_to_ragged2d(
176175
np.array(padded_array, dtype=float))
177176
self.assertAllEqual(expected, actual)

0 commit comments

Comments
 (0)