
Commit ab62e47

refine dtype use (#9366)
1 parent 565a6ab commit ab62e47

6 files changed: 7 additions & 13 deletions

File tree

slm/model_zoo/bert/static/run_glue.py
slm/model_zoo/bert/static/run_glue_with_sparaity.py
slm/model_zoo/bert/static/run_pretrain.py
slm/model_zoo/gpt-3/ppfleetx/optims/optimizer.py
slm/model_zoo/moe/dygraph/framework/group_sharded.py
slm/model_zoo/moe/dygraph/run_moe_pretrain.py

slm/model_zoo/bert/static/run_glue.py

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ def reset_program_state_dict(args, model, state_dict, pretrained_state_dict):
             reset_parameter_names.append(n)
         else:
             dtype_str = "float32"
-            if str(p.dtype) == "VarType.FP64":
+            if p.dtype == paddle.float64:
                 dtype_str = "float64"
             reset_state_dict[p.name] = np.random.normal(loc=0.0, scale=scale, size=p.shape).astype(dtype_str)
     logger.info("the following parameter had reset, please check. {}".format(reset_parameter_names))

slm/model_zoo/bert/static/run_glue_with_sparaity.py

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ def reset_program_state_dict(args, model, state_dict, pretrained_state_dict):
             reset_parameter_names.append(n)
         else:
             dtype_str = "float32"
-            if str(p.dtype) == "VarType.FP64":
+            if p.dtype == paddle.float64:
                 dtype_str = "float64"
             reset_state_dict[p.name] = np.random.normal(loc=0.0, scale=scale, size=p.shape).astype(dtype_str)
     logger.info("the following parameter had reset, please check. {}".format(reset_parameter_names))

slm/model_zoo/bert/static/run_pretrain.py

Lines changed: 1 addition & 1 deletion
@@ -151,7 +151,7 @@ def reset_program_state_dict(model, state_dict):
     for n, p in state_dict.items():
         if "layer_norm" not in p.name:
             dtype_str = "float32"
-            if str(p.dtype) == "VarType.FP64":
+            if p.dtype == paddle.float64:
                 dtype_str = "float64"
             new_state_dict[p.name] = np.random.normal(loc=0.0, scale=scale, size=p.shape).astype(dtype_str)
     return new_state_dict

slm/model_zoo/gpt-3/ppfleetx/optims/optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def __init__(self, learning_rate, parameters, grad_clip, **config):
     def _add_moments_pows(self, p):
         acc_dtype = p.dtype
         if self._is_dtype_fp16_or_bf16(acc_dtype):
-            acc_dtype = core.VarDesc.VarType.FP32
+            acc_dtype = paddle.float32
         self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype, device="cpu")
         self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype, device="cpu")
         self._add_accumulator(
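
The same pattern in isolation, as a rough sketch with assumed names (this is not the ppfleetx optimizer itself): moment accumulators for low-precision parameters are promoted to float32, now expressed with paddle dtype constants rather than core.VarDesc.VarType.

# Hypothetical helper, for illustration only.
import paddle

def moment_dtype(param):
    # Keep moment statistics in float32 when the parameter is fp16/bf16,
    # otherwise reuse the parameter's own dtype.
    if param.dtype in (paddle.float16, paddle.bfloat16):
        return paddle.float32
    return param.dtype

w = paddle.ones([4], dtype="float16")
m1 = paddle.zeros(w.shape, dtype=moment_dtype(w))
print(m1.dtype)  # paddle.float32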

slm/model_zoo/moe/dygraph/framework/group_sharded.py

Lines changed: 1 addition & 2 deletions
@@ -37,7 +37,6 @@
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import (
     GroupShardedStage2,
 )
-from paddle.framework import core
 from paddle.incubate.distributed.models.moe.grad_clip import ClipGradForMOEByGlobalNorm
 from paddle.optimizer import Optimizer
 
@@ -99,7 +98,7 @@ def _dygraph_clip(self, params_grads):
             params_and_grads.append((p, g))
             continue
         # TODO(wangxi): use inplace elementwise_mul
-        clip_input = clip_var.astype("float16") if g.dtype == core.VarDesc.VarType.FP16 else clip_var
+        clip_input = clip_var.astype("float16") if g.dtype == paddle.float16 else clip_var
         new_grad = paddle.multiply(x=g, y=clip_input)
         params_and_grads.append((p, new_grad))
     return params_and_grads
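
Condensed version of the pattern in _dygraph_clip, with made-up inputs (not the MoE clipper itself): the clip coefficient is cast to float16 only when the gradient is float16, and the dtype check now uses paddle.float16 directly.

# Hypothetical values, for illustration only.
import paddle

clip_var = paddle.to_tensor(0.5, dtype="float32")
g = paddle.ones([3], dtype="float16")  # stands in for a low-precision gradient

# paddle.multiply requires matching dtypes, so the coefficient is cast to match.
clip_input = clip_var.astype("float16") if g.dtype == paddle.float16 else clip_var
new_grad = paddle.multiply(x=g, y=clip_input)
print(new_grad.dtype)  # paddle.float16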

slm/model_zoo/moe/dygraph/run_moe_pretrain.py

Lines changed: 2 additions & 7 deletions
@@ -37,7 +37,6 @@
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
     GroupShardedScaler,
 )
-from paddle.framework import core
 from paddle.incubate.distributed.models import moe
 from utils import get_timers, set_timers
 from visualdl import LogWriter
@@ -158,12 +157,8 @@ def initialize_mp_dp_parameters(model, hcg):
 def unscale_method(self, optimizer):
     if not self._enable:
         return
-    if paddle.framework.use_pir_api():
-        type_float16 = core.DataType.FLOAT16
-        type_float32 = core.DataType.FLOAT32
-    else:
-        type_float16 = core.VarDesc.VarType.FP16
-        type_float32 = core.VarDesc.VarType.FP32
+    type_float16 = paddle.float16
+    type_float32 = paddle.float32
 
     if getattr(optimizer, "_param_groups", None) and isinstance(optimizer._param_groups[0], dict):
         param_grads_fp16 = []
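
A condensed sketch of the bookkeeping that follows in unscale_method, using a hypothetical helper (not the GroupShardedScaler patch itself): with the unified paddle.float16 / paddle.float32 constants there is no longer any need to branch on whether the PIR API is in use.

# Hypothetical helper, for illustration only.
import paddle

def split_grads_by_dtype(grads):
    type_float16 = paddle.float16
    type_float32 = paddle.float32
    fp16, fp32 = [], []
    for g in grads:
        if g.dtype == type_float16:
            fp16.append(g)
        elif g.dtype == type_float32:
            fp32.append(g)
    return fp16, fp32

grads = [paddle.ones([2], dtype="float16"), paddle.ones([2], dtype="float32")]
fp16, fp32 = split_grads_by_dtype(grads)
print(len(fp16), len(fp32))  # 1 1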
