Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 73 additions & 5 deletions python/paddle/amp/grad_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class GradScaler(AmpScaler):
Examples:

.. code-block:: python

import paddle

model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
Expand Down Expand Up @@ -91,7 +91,7 @@ def scale(self, var):
Examples:

.. code-block:: python

import paddle

model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
Expand Down Expand Up @@ -156,6 +156,7 @@ def is_enable(self):
Examples:
.. code-block:: python

# required: gpu,xpu
import paddle
scaler = paddle.amp.GradScaler(enable=True,
init_loss_scaling=1024,
Expand All @@ -178,7 +179,8 @@ def is_use_dynamic_loss_scaling(self):

Examples:
.. code-block:: python


# required: gpu,xpu
import paddle
scaler = paddle.amp.GradScaler(enable=True,
init_loss_scaling=1024,
Expand All @@ -202,6 +204,7 @@ def get_init_loss_scaling(self):
Examples:
.. code-block:: python

# required: gpu,xpu
import paddle
scaler = paddle.amp.GradScaler(enable=True,
init_loss_scaling=1024,
Expand All @@ -220,11 +223,12 @@ def set_init_loss_scaling(self, new_init_loss_scaling):
Set the initial loss scaling factor by `new_init_loss_scaling`.

Args:
new_init_loss_scaling(int): The new_init_loss_scaling used to update initial loss scaling factor.
new_init_loss_scaling(float): The new_init_loss_scaling used to update initial loss scaling factor.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确认一下类型变更之后对原来的情况是否完全兼容?比如参数类型检查是否有相关设置

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks,原本数据类型是float,初次添加注释的时候错误提交了int,这里修改为了正确的数据类型float。


Examples:
.. code-block:: python


# required: gpu,xpu
import paddle
scaler = paddle.amp.GradScaler(enable=True,
init_loss_scaling=1024,
Expand All @@ -250,6 +254,7 @@ def get_incr_ratio(self):
Examples:
.. code-block:: python

# required: gpu,xpu
import paddle
scaler = paddle.amp.GradScaler(enable=True,
init_loss_scaling=1024,
Expand All @@ -273,6 +278,7 @@ def set_incr_ratio(self, new_incr_ratio):
Examples:
.. code-block:: python

# required: gpu,xpu
import paddle
scaler = paddle.amp.GradScaler(enable=True,
init_loss_scaling=1024,
Expand All @@ -298,6 +304,7 @@ def get_decr_ratio(self):
Examples:
.. code-block:: python

# required: gpu,xpu
import paddle
scaler = paddle.amp.GradScaler(enable=True,
init_loss_scaling=1024,
Expand All @@ -321,6 +328,7 @@ def set_decr_ratio(self, new_decr_ratio):
Examples:
.. code-block:: python

# required: gpu,xpu
import paddle
scaler = paddle.amp.GradScaler(enable=True,
init_loss_scaling=1024,
Expand All @@ -346,6 +354,7 @@ def get_incr_every_n_steps(self):
Examples:
.. code-block:: python

# required: gpu,xpu
import paddle
scaler = paddle.amp.GradScaler(enable=True,
init_loss_scaling=1024,
Expand All @@ -369,6 +378,7 @@ def set_incr_every_n_steps(self, new_incr_every_n_steps):
Examples:
.. code-block:: python

# required: gpu,xpu
import paddle
scaler = paddle.amp.GradScaler(enable=True,
init_loss_scaling=1024,
Expand All @@ -394,6 +404,7 @@ def get_decr_every_n_nan_or_inf(self):
Examples:
.. code-block:: python

# required: gpu,xpu
import paddle
scaler = paddle.amp.GradScaler(enable=True,
init_loss_scaling=1024,
Expand All @@ -417,6 +428,7 @@ def set_decr_every_n_nan_or_inf(self, new_decr_every_n_nan_or_inf):
Examples:
.. code-block:: python

# required: gpu,xpu
import paddle
scaler = paddle.amp.GradScaler(enable=True,
init_loss_scaling=1024,
Expand All @@ -432,3 +444,59 @@ def set_decr_every_n_nan_or_inf(self, new_decr_every_n_nan_or_inf):
"""
super(GradScaler,
self).set_decr_every_n_nan_or_inf(new_decr_every_n_nan_or_inf)

def state_dict(self):
    """
    Export the state of the scaler as a `dict` by delegating to the parent scaler class.

    Returns:
        dict: An empty dict when this instance is not enabled. Otherwise the dict built by the parent implementation, which holds:
            scale: The current loss scaling factor, obtained from the scale tensor via `numpy()`.
            incr_ratio(float): The multiplier to use when increasing the loss scaling.
            decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling.
            incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients.
            decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients.
            incr_count(int): The number of recent consecutive unskipped steps.
            decr_count(int): The number of recent consecutive skipped steps.
            use_dynamic_loss_scaling(bool): Whether dynamic loss scaling is used.

    Examples:

        .. code-block:: python

            # required: gpu,xpu
            import paddle

            scaler = paddle.amp.GradScaler(enable=True,
                                           init_loss_scaling=1024,
                                           incr_ratio=2.0,
                                           decr_ratio=0.5,
                                           incr_every_n_steps=1000,
                                           decr_every_n_nan_or_inf=2,
                                           use_dynamic_loss_scaling=True)
            scaler_state = scaler.state_dict()
    """
    # The parent class owns every piece of scaler state; this wrapper only
    # exists to carry the public documentation and example.
    exported_state = super(GradScaler, self).state_dict()
    return exported_state

def load_state_dict(self, state_dict):
    """
    Restore the scaler from a previously exported state by delegating to the parent scaler class.

    Args:
        state_dict(dict): The scaler state to restore. It should be an object returned from a call to `GradScaler.state_dict()`.

    Examples:

        .. code-block:: python

            # required: gpu,xpu
            import paddle

            scaler = paddle.amp.GradScaler(enable=True,
                                           init_loss_scaling=1024,
                                           incr_ratio=2.0,
                                           decr_ratio=0.5,
                                           incr_every_n_steps=1000,
                                           decr_every_n_nan_or_inf=2,
                                           use_dynamic_loss_scaling=True)
            scaler_state = scaler.state_dict()
            scaler.load_state_dict(scaler_state)
    """
    # Nothing is returned: the parent call's result is intentionally discarded.
    parent_scaler = super(GradScaler, self)
    parent_scaler.load_state_dict(state_dict)
52 changes: 52 additions & 0 deletions python/paddle/fluid/dygraph/amp/loss_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,55 @@ def set_decr_every_n_nan_or_inf(self, new_decr_every_n_nan_or_inf):
new_decr_every_n_nan_or_inf(int): The new_decr_every_n_nan_or_inf used to update the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients.
"""
self._decr_every_n_nan_or_inf = new_decr_every_n_nan_or_inf

def state_dict(self):
    """
    Export the state of the scaler as a `dict`. If this instance is not enabled, an empty dict is returned.

    Returns:
        dict: The scaler state, which holds:
            scale: The loss scaling factor, obtained from the scale tensor via `numpy()`.
            incr_ratio(float): The multiplier to use when increasing the loss scaling.
            decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling.
            incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients.
            decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients.
            incr_count(int): The number of recent consecutive unskipped steps.
            decr_count(int): The number of recent consecutive skipped steps.
            use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamically.
    """
    # A disabled scaler exports nothing; none of its attributes are read.
    if not self._enable:
        return {}

    snapshot = {}
    snapshot["scale"] = self._scale.numpy()
    snapshot["incr_ratio"] = self._incr_ratio
    snapshot["decr_ratio"] = self._decr_ratio
    snapshot["incr_every_n_steps"] = self._incr_every_n_steps
    snapshot["decr_every_n_nan_or_inf"] = self._decr_every_n_nan_or_inf
    snapshot["incr_count"] = self._incr_count
    snapshot["decr_count"] = self._decr_count
    snapshot["use_dynamic_loss_scaling"] = self._use_dynamic_loss_scaling
    return snapshot

def load_state_dict(self, state_dict):
    """
    Restore the scaler from a previously exported state.

    When this instance is not enabled the call is a no-op and `state_dict` is not inspected.

    Args:
        state_dict(dict): The scaler state to restore. It should be an object returned from a call to `AmpScaler.state_dict()`.

    Raises:
        RuntimeError: If this instance is enabled and `state_dict` is empty.
    """
    if self._enable:
        if len(state_dict) == 0:
            raise RuntimeError(
                "The input state dict is empty, possibly because it was saved "
                "from a disabled instance of GradScaler.")

        # The saved scale is an array-like value; its first element becomes
        # both the recorded initial scaling and the live scale variable.
        restored_scale = state_dict["scale"][0]
        self._init_loss_scaling = restored_scale
        scale_array = np.array([restored_scale]).astype(np.float32)
        self._scale = to_variable(scale_array)

        # Every remaining entry maps onto the private attribute of the same
        # name; the order matches the order the entries are exported in.
        for key in ("incr_ratio", "decr_ratio", "incr_every_n_steps",
                    "decr_every_n_nan_or_inf", "incr_count", "decr_count",
                    "use_dynamic_loss_scaling"):
            setattr(self, "_" + key, state_dict[key])
Loading