From fff2ce5fa47828e6d2033eb5bf0328ed85b861ae Mon Sep 17 00:00:00 2001 From: zhouxin Date: Wed, 13 Aug 2025 09:00:51 +0000 Subject: [PATCH] Fix test_get_autocast_dtype on FP16-unsupported device --- test/amp/test_get_autocast_dtype.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/amp/test_get_autocast_dtype.py b/test/amp/test_get_autocast_dtype.py index dfd3ea2c91cb73..ef8ef989ec24e3 100644 --- a/test/amp/test_get_autocast_dtype.py +++ b/test/amp/test_get_autocast_dtype.py @@ -44,18 +44,30 @@ def test_amp_autocast_fp16(self): self.do_test(device, "float16") self.do_test(device, self.default_dtype) + @unittest.skipIf( + not paddle.amp.is_bfloat16_supported(), + "Skip BF16 test if BF16 is not supported", + ) def test_amp_autocast_bf16(self): for device in self.device_list: with paddle.amp.auto_cast(True, dtype="bfloat16"): self.do_test(device, "bfloat16") self.do_test(device, self.default_dtype) + @unittest.skipIf( + not paddle.amp.is_bfloat16_supported(), + "Skip BF16 test if BF16 is not supported", + ) def test_amp_autocast_false_bf16(self): for device in self.device_list: with paddle.amp.auto_cast(True, dtype="bfloat16"): self.do_test(device, "bfloat16") self.do_test(device, self.default_dtype) + @unittest.skipIf( + not paddle.amp.is_bfloat16_supported(), + "Skip BF16 test if BF16 is not supported", + ) def test_amp_nested_context(self): for device in self.device_list: with paddle.amp.auto_cast(True, dtype="bfloat16"):