Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions python/paddle/distributed/fleet/base/fleet_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,26 @@ def forward(self, x):
# imitate target optimizer retrieval
return self.user_defined_optimizer.clear_grad()

def _get_amp_optimizer(self):
# imitate target optimizer retrieval
amp_optimizer = None
for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
if hasattr(optimizer, 'amp_init'):
amp_optimizer = optimizer
break

if amp_optimizer is None:
if hasattr(self.user_defined_optimizer, 'amp_init'):
amp_optimizer = self.user_defined_optimizer

assert amp_optimizer is not None, \
"amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
return amp_optimizer

def get_loss_scaling(self):
amp_optimizer = self._get_amp_optimizer()
return amp_optimizer.get_loss_scaling()

def amp_init(self,
place,
scope=None,
Expand Down Expand Up @@ -1101,21 +1121,7 @@ def run_example_code():
if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
run_example_code()
"""

# imitate target optimizer retrieval
amp_optimizer = None
for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
if hasattr(optimizer, 'amp_init'):
amp_optimizer = optimizer
break

if amp_optimizer is None:
if hasattr(self.user_defined_optimizer, 'amp_init'):
amp_optimizer = self.user_defined_optimizer

assert amp_optimizer is not None, \
"amp_init can only be used when the amp(auto mixed precision) strategy is turned on."

amp_optimizer = self._get_amp_optimizer()
return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test)

def _final_strategy(self):
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/contrib/mixed_precision/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def _set_distributed(self, flag):
def get_loss_scaling(self):
"""Return the real-time loss scaling factor.
"""
assert self._loss_scaling is not None, 'Call minimize() before calling get_loss_scaling()'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加一个敬语Please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

下个pr一起改

return self._loss_scaling

def get_scaled_loss(self):
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/fluid/tests/unittests/test_fleet_amp_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def test_fleet_amp_init(self):
optimizer = fleet.distributed_optimizer(optimizer)
optimizer.minimize(cost)

loss_scale = optimizer.get_loss_scaling()

place = paddle.CUDAPlace(0)

exe = paddle.static.Executor(place)
Expand Down