Commit a37ca86

changes after review

1 parent: ce30ebb
File tree: 2 files changed (+7, −41 lines)

paddle/fluid/operators/cast_op.cc (1 addition, 2 deletions)

@@ -90,8 +90,7 @@ class CastOp : public framework::OperatorWithKernel {
     int dtype_fp32 = (int)framework::proto::VarType::FP32;
     int dtype_bf16 = (int)framework::proto::VarType::BF16;
 
-    if (in_dtype != dtype_fp32 && in_dtype != dtype_bf16) return false;
-    if (out_dtype != dtype_fp32 && out_dtype != dtype_bf16) return false;
+    if ((in_dtype != dtype_fp32 && in_dtype != dtype_bf16) || (out_dtype != dtype_fp32 && out_dtype != dtype_bf16)) return false;
 
     return true;
   };
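
The reviewed change folds the two early returns into a single condition: return false if either dtype lies outside {FP32, BF16}. A quick way to convince yourself the fold is safe is to compare both forms on every dtype pair; below is a minimal standalone Python sketch, where FP32/BF16/INT32 are stand-in integer codes rather than Paddle's actual proto::VarType values.

# Sketch only: verify the folded condition matches the two early returns
# for every (in_dtype, out_dtype) pair. FP32/BF16/INT32 are stand-in
# codes, not Paddle's real proto::VarType values.
from itertools import product

FP32, BF16, INT32 = 0, 1, 2  # illustrative codes only

def two_early_returns(in_dtype, out_dtype):
    # Shape of the code before review: two separate early returns.
    if in_dtype != FP32 and in_dtype != BF16:
        return False
    if out_dtype != FP32 and out_dtype != BF16:
        return False
    return True

def folded(in_dtype, out_dtype):
    # Shape of the code after review: one combined condition.
    if ((in_dtype != FP32 and in_dtype != BF16) or
            (out_dtype != FP32 and out_dtype != BF16)):
        return False
    return True

for a, b in product([FP32, BF16, INT32], repeat=2):
    assert two_early_returns(a, b) == folded(a, b)
print("both forms agree on all 9 dtype pairs")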

python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py (6 additions, 39 deletions)

@@ -1,4 +1,4 @@
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -48,11 +48,11 @@ def test_check_grad(self):
             core.CPUPlace(), ["X"],
             "Out",
             check_dygraph=False,
-            user_defined_grads=[self.x_bf16],
-            user_defined_grad_outputs=[self.x_fp32])
+            user_defined_grads=[self.inputs['X']],
+            user_defined_grad_outputs=[self.outputs['Out']])
 
 
-class TestCastFP32ToBF16MKLDNNOp(OpTest):
+class TestCastFP32ToBF16MKLDNNOp(TestCastBF16ToFP32MKLDNNOp):
     def setUp(self):
         self.x_fp32 = np.random.random(size=[2, 6]).astype("float32")
         self.x_bf16 = convert_float_to_uint16(self.x_fp32)
@@ -66,19 +66,8 @@ def setUp(self):
         }
         self.op_type = 'cast'
 
-    def test_check_output(self):
-        self.check_output(check_dygraph=False)
-
-    def test_check_grad(self):
-        self.check_grad_with_place(
-            core.CPUPlace(), ["X"],
-            "Out",
-            check_dygraph=False,
-            user_defined_grads=[self.x_fp32],
-            user_defined_grad_outputs=[self.x_bf16])
-
 
-class TestCastBF16ToBF16MKLDNNOp(OpTest):
+class TestCastBF16ToBF16MKLDNNOp(TestCastBF16ToFP32MKLDNNOp):
     def setUp(self):
         self.x_fp32 = np.random.random(size=[6, 13]).astype("float32")
         self.x_bf16 = convert_float_to_uint16(self.x_fp32)
@@ -92,19 +81,8 @@ def setUp(self):
         }
         self.op_type = 'cast'
 
-    def test_check_output(self):
-        self.check_output(check_dygraph=False)
 
-    def test_check_grad(self):
-        self.check_grad_with_place(
-            core.CPUPlace(), ["X"],
-            "Out",
-            check_dygraph=False,
-            user_defined_grads=[self.x_bf16],
-            user_defined_grad_outputs=[self.x_bf16])
-
-
-class TestCastFP32ToFP32MKLDNNOp(OpTest):
+class TestCastFP32ToFP32MKLDNNOp(TestCastBF16ToFP32MKLDNNOp):
     def setUp(self):
         self.x_fp32 = np.random.random(size=[7, 15]).astype("float32")
 
@@ -117,17 +95,6 @@ def setUp(self):
         }
         self.op_type = 'cast'
 
-    def test_check_output(self):
-        self.check_output(check_dygraph=False)
-
-    def test_check_grad(self):
-        self.check_grad_with_place(
-            core.CPUPlace(), ["X"],
-            "Out",
-            check_dygraph=False,
-            user_defined_grads=[self.x_fp32],
-            user_defined_grad_outputs=[self.x_fp32])
-
 
 if __name__ == '__main__':
     paddle.enable_static()
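
The test-side deletions work because the other cast test classes now inherit from TestCastBF16ToFP32MKLDNNOp instead of OpTest: the base class's test_check_output and test_check_grad read self.inputs['X'] and self.outputs['Out'] rather than hard-coded arrays, so each subclass only overrides setUp. A minimal standalone sketch of that pattern follows, using plain unittest with hypothetical class and attribute names, not Paddle's OpTest API.

# Sketch only: the base class owns the test methods and reads everything
# from attributes populated in setUp(); subclasses override setUp() alone
# and inherit the tests. Names here are hypothetical, not Paddle's OpTest.
import unittest

class BaseCastTest(unittest.TestCase):
    def setUp(self):
        self.inputs = {'X': [1.0, 2.0]}
        self.outputs = {'Out': [1.0, 2.0]}

    def test_check_output(self):
        # Runs once per class; sees each subclass's setUp() data.
        self.assertEqual(self.inputs['X'], self.outputs['Out'])

class OtherShapeCastTest(BaseCastTest):
    def setUp(self):
        self.inputs = {'X': [3.0, 4.0, 5.0]}
        self.outputs = {'Out': [3.0, 4.0, 5.0]}

if __name__ == '__main__':
    unittest.main()  # discovers the inherited test in both classes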
