Commit 48b7279

support cast op from FP32 to low precision (#60385)
1 parent 8b2b953 commit 48b7279

File tree

2 files changed: +21 -0 lines


python/paddle/distributed/passes/auto_parallel_amp.py

Lines changed: 5 additions & 0 deletions
@@ -232,6 +232,11 @@ def build_state(self):
         return is_train
 
     def _mark_black_white_ops(self, op, ops, block):
+        # deal auto_cast info
+        if not op.amp_options.enable:
+            self._op_fp16_dict[op.desc.original_id()] = False
+            return
+
         # ernie inference trick
         if op.type == "assign" and "array_" in op.input_arg_names[0]:
             self._op_fp16_dict[op.desc.original_id()] = False
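With this guard, any op whose recorded `amp_options` disable AMP is pinned to FP32 before the white/black-list rules run. As a hedged illustration of how such a region is typically expressed (assuming, as in Paddle's dygraph AMP API, that `paddle.amp.auto_cast` regions are what populate each op's `amp_options`):

import paddle

x = paddle.rand([4, 4])
with paddle.amp.auto_cast(enable=True, dtype='float16'):
    y = paddle.matmul(x, x)  # eligible for FP16 under the AMP list rules
    with paddle.amp.auto_cast(enable=False):
        # Ops captured here carry amp_options.enable == False, so the
        # guard in _mark_black_white_ops marks them non-FP16 and returns.
        z = paddle.sum(y)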

python/paddle/distributed/passes/auto_parallel_fp16.py

Lines changed: 16 additions & 0 deletions
@@ -209,6 +209,9 @@ def _build_state(self):
         for block in self.program.blocks:
             self.resolute_tensor_dtype(block)
 
+        for block in self.program.blocks:
+            self.resolute_cast_op(block)
+
         # insert cast ops
         for block in self.program.blocks:
             self.cast_block(block)
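Note on ordering: the new loop runs after `resolute_tensor_dtype` has settled every tensor's final dtype and before `cast_block` inserts new cast ops, so existing cast ops are re-synced against the final variable dtypes first.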
@@ -296,6 +299,19 @@ def set_var_to_fp16(self, var_name, block):
         if var.dtype == core.VarDesc.VarType.FP32:
             var.desc.set_dtype(__target_dtype__)
 
+    def resolute_cast_op(self, block):
+        """
+        Deal the "cast_op" from "FP32" to "FP16" or "BF16" in the model.
+        """
+        for op in block.ops:
+            if op.type == "cast":
+                in_name = op.input('X')[0]
+                out_name = op.output('Out')[0]
+                in_var = block._find_var_recursive(in_name)
+                out_var = block._find_var_recursive(out_name)
+                op._set_attr("in_dtype", in_var.dtype)
+                op._set_attr("out_dtype", out_var.dtype)
+
     def resolute_tensor_dtype(self, block):
         for op in block.ops:
             # 'amp_options' flag has highest priority
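To see what `resolute_cast_op` compensates for, here is a hedged sketch (not part of the commit): a `cast` op stores `in_dtype`/`out_dtype` as attributes, so once an earlier pass such as `resolute_tensor_dtype` rewrites a variable to FP16/BF16, the stored attributes go stale until this pass re-reads them from the variables.

import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name='x', shape=[2, 2], dtype='float32')
    y = paddle.cast(x, 'float16')  # emits a "cast" op with dtype attrs

block = main.global_block()
for op in block.ops:
    if op.type == "cast":
        in_var = block._find_var_recursive(op.input('X')[0])
        out_var = block._find_var_recursive(op.output('Out')[0])
        # If a pass later flips in_var/out_var to a low-precision dtype,
        # these stored attrs no longer match the variables; the new
        # resolute_cast_op refreshes both from the variables themselves.
        print(op.attr("in_dtype"), in_var.dtype)
        print(op.attr("out_dtype"), out_var.dtype)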
