Skip to content

Commit d330635

Browse files
fix unit test for nn.init.eye_
1 parent 1d4550e commit d330635

File tree

1 file changed

+5
-10
lines changed

1 file changed

+5
-10
lines changed

test/legacy_test/test_nn_init_function.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,16 +1041,11 @@ def test_fp16(self):
10411041
class Test_eye_(unittest.TestCase):
10421042

10431043
def check(self, tensor):
1044-
for i in range(tensor.shape[0]):
1045-
for j in range(tensor.shape[1]):
1046-
if i == j:
1047-
self.assertEqual(
1048-
tensor[i][j], 1, f"{tensor[i][j]}, {i}, {j}"
1049-
)
1050-
else:
1051-
self.assertEqual(
1052-
tensor[i][j], 0, f"{tensor[i][j]}, {i}, {j}"
1053-
)
1044+
if not isinstance(tensor, np.ndarray):
1045+
tensor = tensor.numpy()
1046+
row, col = tensor.shape
1047+
expected = np.eye(row, col)
1048+
self.assertEqual((tensor == expected).all(), True)
10541049

10551050
def test_linear_dygraph(self):
10561051
with dygraph_guard():

0 commit comments

Comments
 (0)