Skip to content

Commit 47172c0

Browse files
add docs
1 parent 1c3d1a3 commit 47172c0

File tree

1 file changed

+34
-1
lines changed

1 file changed

+34
-1
lines changed

python/paddle/amp/auto_cast.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1319,11 +1319,27 @@ def decorate(
13191319
def is_autocast_enabled(device_type: PlaceLike | None = None) -> bool:
13201320
"""
13211321
Check whether auto-mixed-precision is enabled in the current context.
1322+
13221323
Args:
13231324
device_type (PlaceLike, optional): The device type to check. This argument is ignored for all devices sharing the same AMP state in paddlepaddle.
13241325
13251326
Returns:
13261327
bool: True if auto-mixed-precision is enabled, False otherwise.
1328+
1329+
Examples:
1330+
.. code-block:: python
1331+
1332+
>>> # doctest: +REQUIRES(env:GPU)
1333+
>>> # Demo1: Check if auto-mixed-precision is enabled by default
1334+
>>> import paddle
1335+
>>> paddle.device.set_device('gpu')
1336+
>>> print(paddle.is_autocast_enabled())
1337+
False
1338+
1339+
>>> # Demo2: Enable auto-mixed-precision and check again
1340+
>>> with paddle.amp.auto_cast():
1341+
... print(paddle.is_autocast_enabled())
1342+
True
13271343
"""
13281344
if in_pir_mode():
13291345
amp_attrs = core._get_amp_attrs()
@@ -1338,9 +1354,26 @@ def is_autocast_enabled(device_type: PlaceLike | None = None) -> bool:
13381354
def get_autocast_dtype(device_type: PlaceLike | None = None) -> _DTypeLiteral:
13391355
"""
13401356
Get the auto-mixed-precision dtype in the current context.
1357+
13411358
Args:
13421359
device_type (PlaceLike, optional): The device type to check. This argument is ignored for all devices sharing the same AMP state in paddlepaddle.
1360+
13431361
Returns:
1344-
_DTypeLiteral: The current AMP dtype ('float16', 'bfloat16')
1362+
_DTypeLiteral: The current AMP dtype.
1363+
1364+
Examples:
1365+
.. code-block:: python
1366+
1367+
>>> # doctest: +REQUIRES(env:GPU)
1368+
>>> # Demo1: Get default auto-mixed-precision dtype
1369+
>>> import paddle
1370+
>>> paddle.device.set_device('gpu')
1371+
>>> print(paddle.get_autocast_dtype())
1372+
float32
1373+
1374+
>>> # Demo2: Enable auto-mixed-precision and get the dtype
1375+
>>> with paddle.amp.auto_cast():
1376+
... print(paddle.get_autocast_dtype())
1377+
float16
13451378
"""
13461379
return amp_global_state().amp_dtype

0 commit comments

Comments
 (0)