Skip to content

Commit de057cd

Browse files
fix place cov
1 parent 69ef801 commit de057cd

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

test/legacy_test/test_place_guard.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,15 @@ def test_str_place_obj_nested(self):
131131
self.assertEqual(x.place, place_obj1)
132132
self.assertNotEqual(x.place, place_obj2)
133133

134+
def test_place_str_cuda(self):
135+
if (
136+
paddle.device.is_compiled_with_cuda()
137+
and not paddle.device.is_compiled_with_rocm()
138+
):
139+
with paddle.device.device_guard("gpu"):
140+
tensor_cuda = paddle.randn([3, 3], device="cuda:0")
141+
self.assertEqual(tensor_cuda.place, paddle.CUDAPlace(0))
142+
134143

135144
if __name__ == "__main__":
136145
unittest.main()

0 commit comments

Comments
 (0)