@@ -1319,11 +1319,27 @@ def decorate(
13191319def is_autocast_enabled (device_type : PlaceLike | None = None ) -> bool :
13201320 """
13211321 Check whether auto-mixed-precision is enabled in the current context.
1322+
13221323 Args:
13231324 device_type (PlaceLike, optional): The device type to check. This argument is ignored for all devices sharing the same AMP state in paddlepaddle.
13241325
13251326 Returns:
13261327 bool: True if auto-mixed-precision is enabled, False otherwise.
1328+
1329+ Examples:
1330+ .. code-block:: python
1331+
1332+ >>> # doctest: +REQUIRES(env:GPU)
1333+ >>> # Demo1: Check if auto-mixed-precision is enabled by default
1334+ >>> import paddle
1335+ >>> paddle.device.set_device('gpu')
1336+ >>> print(paddle.is_autocast_enabled())
1337+ False
1338+
1339+ >>> # Demo2: Enable auto-mixed-precision and check again
1340+ >>> with paddle.amp.auto_cast():
1341+ ... print(paddle.is_autocast_enabled())
1342+ True
13271343 """
13281344 if in_pir_mode ():
13291345 amp_attrs = core ._get_amp_attrs ()
@@ -1338,9 +1354,26 @@ def is_autocast_enabled(device_type: PlaceLike | None = None) -> bool:
13381354def get_autocast_dtype (device_type : PlaceLike | None = None ) -> _DTypeLiteral :
13391355 """
13401356 Get the auto-mixed-precision dtype in the current context.
1357+
13411358 Args:
13421359 device_type (PlaceLike, optional): The device type to check. This argument is ignored for all devices sharing the same AMP state in paddlepaddle.
1360+
13431361 Returns:
1344- _DTypeLiteral: The current AMP dtype ('float16', 'bfloat16')
1362+ _DTypeLiteral: The current AMP dtype.
1363+
1364+ Examples:
1365+ .. code-block:: python
1366+
1367+ >>> # doctest: +REQUIRES(env:GPU)
1368+ >>> # Demo1: Get default auto-mixed-precision dtype
1369+ >>> import paddle
1370+ >>> paddle.device.set_device('gpu')
1371+ >>> print(paddle.get_autocast_dtype())
1372+ float32
1373+
1374+ >>> # Demo2: Enable auto-mixed-precision and get the dtype
1375+ >>> with paddle.amp.auto_cast():
1376+ ... print(paddle.get_autocast_dtype())
1377+ float16
13451378 """
13461379 return amp_global_state ().amp_dtype
0 commit comments