Skip to content

Commit df7b3e8

Browse files
LRJKDBen-Louis
authored andcommitted
[Fix]Fix fp16 bug (open-mmlab#2241)
1 parent cc4645a commit df7b3e8

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

mmpose/apis/train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def train_model(model,
115115

116116
# get currently existing device type
117117
device = get_device()
118+
cfg.device = device
118119

119120
# put model on gpus
120121
if distributed:
@@ -165,6 +166,8 @@ def train_model(model,
165166
else:
166167
# fp16 setting
167168
fp16_cfg = cfg.get('fp16', None)
169+
if fp16_cfg is None and cfg.get('device', None) == 'npu':
170+
fp16_cfg = dict(loss_scale='dynamic')
168171
if fp16_cfg is not None:
169172
optimizer_config = Fp16OptimizerHook(
170173
**cfg.optimizer_config, **fp16_cfg, distributed=distributed)

0 commit comments

Comments
 (0)