45 changes: 34 additions & 11 deletions python/paddle/static/amp/fp16_utils.py
@@ -262,7 +262,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                 _rename_arg(op, in_var.name, out_var.name)
 
     for attr_name in ['in_dtype', 'out_dtype', 'dtype']:
-        if op.has_attr(attr_name) and is_float_dtype(op.attr(attr_name)):
+        if op.has_attr(attr_name) and op.attr(attr_name) in FLOAT_TYPES:
             op._set_attr(attr_name, dest_dtype)
 
     return num_cast_ops
@@ -405,13 +405,18 @@ def fp16_guard():
         yield
 
 
-def is_float_dtype(dtype):
-    return (
-        dtype == core.VarDesc.VarType.FP32
-        or dtype == core.VarDesc.VarType.FP16
-        or dtype == core.VarDesc.VarType.BF16
-        or dtype == core.VarDesc.VarType.FP64
-    )
+FLOAT_TYPES = {
+    core.VarDesc.VarType.FP32,
+    core.VarDesc.VarType.FP16,
+    core.VarDesc.VarType.BF16,
+    core.VarDesc.VarType.FP64,
+}
+
+SUPPORT_FLOAT_TYPES = {
+    core.VarDesc.VarType.FP32,
+    core.VarDesc.VarType.FP16,
+    core.VarDesc.VarType.BF16,
+}
 
 
 def set_var_dst_dtype(
@@ -433,7 +438,7 @@ def set_var_dst_dtype(
         if var is None or var.type not in _valid_types:
             continue
 
-        if is_float_dtype(var.dtype):
+        if var.dtype in FLOAT_TYPES:
             low_precison_var_names.add(var_name)
             if need_set_dtype:
                 var.desc.set_dtype(dtype)
Expand Down Expand Up @@ -700,6 +705,25 @@ def cast_model_to_fp16(

def need_process(op):
need_process = True

def is_support_type(name):
if not op.block._find_var_recursive(
name
): # a special case for lod_tensor_blocking_queue_0
return True
if (
op.block._var_recursive(name).type
!= core.VarDesc.VarType.LOD_TENSOR
):
return False
return op.block._var_recursive(name).dtype in SUPPORT_FLOAT_TYPES

if len(op.input_arg_names) > 0 and all(
not is_support_type(name) for name in op.input_arg_names
):
return False

# if input type of op is fp64, we just skip it.
if op.type in ["set_value"]:
# NOTE(zoooo0820): OP set_value has attribute "dtype", but its output type is
# determined by the input.dtype instead of attribute. So, here we still process it.
@@ -711,8 +735,7 @@ def need_process(op):
                 # output type of some operators such as fill_constant will be determined by the attribute value.
                 #
                 if not op.has_attr('in_dtype') and (
-                    op.has_attr(attr_name)
-                    and is_float_dtype(op.attr(attr_name))
+                    op.has_attr(attr_name) and op.attr(attr_name) in FLOAT_TYPES
                 ):
                     need_process = False
 
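In plain terms, the fp16_utils.py change splits the old is_float_dtype check in two: FLOAT_TYPES (every float width, FP64 included) still drives the dtype-attribute rewriting, while the narrower SUPPORT_FLOAT_TYPES gates need_process, so an op whose inputs are all dense FP64 variables is now skipped by the O2 cast pass instead of being forced down to FP16. Below is a minimal, self-contained sketch of that gating idea; the DType enum and should_process helper are illustrative stand-ins written for this note, not Paddle APIs.

from enum import Enum, auto


class DType(Enum):
    # Illustrative stand-ins for core.VarDesc.VarType members.
    FP16 = auto()
    BF16 = auto()
    FP32 = auto()
    FP64 = auto()


# Mirrors the split above: FP64 counts as a float dtype, but not one the pass rewrites.
FLOAT_TYPES = {DType.FP32, DType.FP16, DType.BF16, DType.FP64}
SUPPORT_FLOAT_TYPES = {DType.FP32, DType.FP16, DType.BF16}


def should_process(input_dtypes):
    """Skip an op when every one of its input dtypes falls outside SUPPORT_FLOAT_TYPES."""
    if input_dtypes and all(d not in SUPPORT_FLOAT_TYPES for d in input_dtypes):
        return False
    return True


assert should_process([DType.FP32, DType.FP64])  # mixed inputs: still processed
assert not should_process([DType.FP64])          # all-FP64 op: left at full precision
assert should_process([])                        # no inputs: processed as before

The new test below then checks the user-visible effect: slicing an FP64 tensor and casting it to float32 under amp.auto_cast(level="O2") should give the same result in dygraph and after paddle.jit.to_static.
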
48 changes: 48 additions & 0 deletions test/dygraph_to_static/test_amp_fp64_case.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from dygraph_to_static_utils import (
+    Dy2StTestBase,
+    test_legacy_and_pt,
+)
+
+import paddle
+
+np.random.seed(1)
+
+
+def func(x):
+    y = x[0:3].astype("float32")
+    return y
+
+
+class TestAmp64Case(Dy2StTestBase):
+    def _run_static(self):
+        static_func = paddle.jit.to_static(func)
+        x = paddle.randn((10, 10)).astype("float64")
+        with paddle.amp.auto_cast(True, level="O2"):
+            dy_out = func(x)
+            st_out = static_func(x)
+            np.testing.assert_allclose(dy_out.numpy(), st_out.numpy())
+
+    @test_legacy_and_pt
+    def test_ast_to_func(self):
+        self._run_static()
+
+
+if __name__ == '__main__':
+    unittest.main()