diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc
index 062e33f26610cc..facef32fa3b5c4 100644
--- a/paddle/fluid/operators/layer_norm_op.cc
+++ b/paddle/fluid/operators/layer_norm_op.cc
@@ -15,7 +15,13 @@ limitations under the License. */
 #include
 #include
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/ternary.h"
 
 namespace paddle {
 namespace operators {
@@ -253,15 +259,78 @@ class LayerNormGradOpMaker : public framework::SingleGradOpMaker {
 DECLARE_NO_NEED_BUFFER_VARS_INFERER(LayerNormGradNoNeedBufferVarInferer,
                                     "Bias");
 
+class LayerNormCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ public:
+  void Apply() override {
+    // get inputs
+    paddle::Tensor x = this->GetSingleForwardInput("X");
+    paddle::Tensor mean = this->GetSingleForwardOutput("Mean");
+    paddle::Tensor var = this->GetSingleForwardOutput("Variance");
+    paddle::Tensor y_grad = this->GetSingleOutputGrad("Y");
+    paddle::optional scale =
+        this->GetOptionalSingleForwardInput("Scale");
+    paddle::optional bias =
+        this->GetOptionalSingleForwardInput("Bias");
+
+    // get Attrs
+    auto epsilon = this->Attr("epsilon");
+    auto begin_norm_axis = this->Attr("begin_norm_axis");
+
+    // get outputs
+    paddle::Tensor x_grad = this->GetSingleInputGrad("X");
+    paddle::Tensor scale_grad = this->GetSingleInputGrad("Scale");
+    paddle::Tensor bias_grad = this->GetSingleInputGrad("Bias");
+
+    auto dx_ptr = this->GetOutputPtr(&x_grad);
+    std::string dx_name = this->GetOutputName(x_grad);
+    auto dscale_ptr = this->GetOutputPtr(&scale_grad);
+    std::string dscale_name = this->GetOutputName(scale_grad);
+    auto dbias_ptr = this->GetOutputPtr(&bias_grad);
+    std::string dbias_name = this->GetOutputName(bias_grad);
+
+    VLOG(6) << "Running layer_norm_grad composite func";
+    prim::layer_norm_grad(x,
+                          scale,
+                          bias,
+                          mean,
+                          var,
+                          y_grad,
+                          epsilon,
+                          begin_norm_axis,
+                          dx_ptr,
+                          dscale_ptr,
+                          dbias_ptr);
+
+    this->RecoverOutputName(x_grad, dx_name);
+    this->RecoverOutputName(scale_grad, dscale_name);
+    this->RecoverOutputName(bias_grad, dbias_name);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+
+DECLARE_INFER_SHAPE_FUNCTOR(layer_norm,
+                            LayerNormInferShapeFunctor,
+                            PD_INFER_META(phi::LayerNormInferMeta));
+
 REGISTER_OPERATOR(layer_norm,
                   ops::LayerNormOp,
                   ops::LayerNormOpMaker,
                   ops::LayerNormGradOpMaker,
-                  ops::LayerNormGradOpMaker);
+                  ops::LayerNormGradOpMaker,
+                  ops::LayerNormCompositeGradOpMaker,
+                  LayerNormInferShapeFunctor);
+
+DECLARE_INFER_SHAPE_FUNCTOR(layer_norm_grad,
+                            LayerNormGradInferShapeFunctor,
+                            PD_INFER_META(phi::LayerNormGradInferMeta));
+
 REGISTER_OPERATOR(layer_norm_grad,
                   ops::LayerNormGradOp,
-                  ops::LayerNormGradNoNeedBufferVarInferer);
+                  ops::LayerNormGradNoNeedBufferVarInferer,
+                  LayerNormGradInferShapeFunctor);
diff --git a/paddle/fluid/prim/api/api.yaml b/paddle/fluid/prim/api/api.yaml
index 529d024b8b8a08..5368c45f814174 100644
--- a/paddle/fluid/prim/api/api.yaml
+++ b/paddle/fluid/prim/api/api.yaml
@@ -29,6 +29,7 @@
 - tile
 - transpose
 - pad
+- sqrt
 - cumsum
 - put_along_axis
 - greater_than
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 304102b733b8c9..a36e311f3df94c 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -877,6 +877,101 @@ void slice_grad(const Tensor& input,
   }
 }
 
+template
+void layer_norm_grad(const Tensor& x,
+                     const paddle::optional& scale,
+                     const paddle::optional& bias,
+                     const Tensor& mean,
+                     const Tensor& variance,
+                     const Tensor& out_grad,
+                     float epsilon,
+                     int begin_norm_axis,
+                     Tensor* x_grad,
+                     Tensor* scale_grad,
+                     Tensor* bias_grad) {
+  auto x_dims = x.dims();
+  auto shape_1 = 1;  // front part
+  auto shape_2 = 1;  // back part
+  for (int i = 0; i < begin_norm_axis; ++i) {
+    shape_1 *= x_dims[i];
+  }
+  for (int i = begin_norm_axis; i < x.dims().size(); ++i) {
+    shape_2 *= x_dims[i];
+  }
+  auto scale_ptr = scale.get_ptr();
+  auto bias_ptr = bias.get_ptr();
+
+  // cast dtype to float32 if dtype is float16
+  Tensor x_cast = x;
+  Tensor out_grad_cast = out_grad;
+  Tensor scale_cast;
+  if (scale_ptr) {
+    scale_cast = reshape(*scale_ptr, std::vector({1, shape_2}));
+  }
+  if (x.dtype() == phi::DataType::FLOAT16) {
+    x_cast = cast(x, phi::DataType::FLOAT32);
+    out_grad_cast = cast(out_grad, phi::DataType::FLOAT32);
+    if (scale_ptr) {
+      scale_cast = cast(scale_cast, phi::DataType::FLOAT32);
+    }
+  }
+
+  x_cast = reshape(x_cast, std::vector({shape_1, shape_2}));
+  out_grad_cast =
+      reshape(out_grad_cast, std::vector({shape_1, shape_2}));
+  auto mean_ = reshape(mean, std::vector({shape_1, 1}));
+  auto variance_ = reshape(variance, std::vector({shape_1, 1}));
+  if (bias_grad) {
+    if (bias_ptr) {
+      auto bias_grad_tmp =
+          out_grad_cast.sum(std::vector({0}), x_cast.dtype(), true);
+      bias_grad_tmp = reshape(bias_grad_tmp, bias_ptr->shape());
+      set_output(bias_grad_tmp, bias_grad);
+    } else {
+      bias_grad = nullptr;
+    }
+  }
+  auto x_sub_mean = x_cast - mean_;
+  auto tmp = (1.0 / (variance_ + epsilon));
+  auto sqrt_var_1 = sqrt(tmp);
+  if (scale_grad) {
+    if (scale_ptr) {
+      auto scale_grad_tmp =
+          (x_sub_mean * sqrt_var_1 * out_grad_cast)
+              .sum(std::vector({0}), x_cast.dtype(), true);
+      scale_grad_tmp = reshape(scale_grad_tmp, scale_ptr->shape());
+      set_output(scale_grad_tmp, scale_grad);
+    } else {
+      scale_grad = nullptr;
+    }
+  }
+
+  if (x_grad) {
+    if (!scale_ptr) {
+      scale_cast =
+          full(std::vector({1, shape_2}), 1.0, x_cast.dtype());
+    }
+    auto out_grad_scale = out_grad_cast * scale_cast;
+    auto dx_end = (sqrt_var_1 * out_grad_scale);
+    auto d_mean_0 =
+        (-dx_end).sum(std::vector({1}), x_cast.dtype(), true);
+    auto d_mean = (1.0 / shape_2) * d_mean_0;
+    auto d_std_1 = (-tmp * x_sub_mean * out_grad_scale)
+                       .sum(std::vector({1}), x_cast.dtype(), true);
+    auto d_std_2 = (1.0 / shape_2) * sqrt_var_1;
+    d_std_2 = reshape(d_std_2, std::vector({shape_1, 1}));
+    d_std_2 = d_std_2 * x_sub_mean;
+    auto d_std = d_std_1 * d_std_2;
+
+    auto x_grad_tmp = dx_end + d_mean + d_std;
+    x_grad_tmp = reshape(x_grad_tmp, phi::vectorize(x.dims()));
+    if (x.dtype() == phi::DataType::FLOAT16) {
+      x_grad_tmp = cast(x_grad_tmp, x.dtype());
+    }
+    set_output(x_grad_tmp, x_grad);
+  }
+}
+
 template
 void cumsum_grad(const Tensor& x,
                  const Tensor& out_grad,
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index a5ea461530c525..38a6ca3155de85 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -629,6 +629,7 @@
   kernel :
     func : layer_norm_grad
     data_type : out_grad
+  composite : layer_norm_grad(x, scale, bias, mean, variance, out_grad, epsilon, begin_norm_axis, x_grad, scale_grad, bias_grad)
   no_need_buffer : bias
   optional : scale, bias
 
diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc
index c1bbc0cff2dd45..7d2d2133467b9e 100644
--- a/paddle/phi/infermeta/ternary.cc
+++ b/paddle/phi/infermeta/ternary.cc
@@ -574,14 +574,23 @@ void LayerNormInferMeta(const MetaTensor& x,
                         right));
   }
 
+  phi::DataType x_dtype = x.dtype();
   out->set_dims(x_dim);
+  out->set_dtype(x_dtype);
+  out->share_lod(x);
+
+  phi::DataType param_type =
+      (x_dtype == phi::DataType::BFLOAT16 || x_dtype == phi::DataType::FLOAT16)
+          ? phi::DataType::FLOAT32
+          : x_dtype;
   if (mean) {
     mean->set_dims({left});
+    mean->set_dtype(param_type);
   }
   if (variance) {
     variance->set_dims({left});
+    variance->set_dtype(param_type);
   }
-  out->share_lod(x);
 }
 
 void LayerNormGradInferMeta(const MetaTensor& x,
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py
index f4d59f1a1552f9..745408bf7dd924 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py
@@ -237,10 +237,12 @@ def test_train(self):
 
     def test_train_composite(self):
         core._set_prim_backward_enabled(True)
+        # core._add_skip_comp_ops("layer_norm")
        static_loss, static_ppl = self.train_static(
             self.bert_config, self.data_reader
         )
         core._set_prim_backward_enabled(False)
+        # core._add_skip_comp_ops("layer_norm")
         dygraph_loss, dygraph_ppl = self.train_dygraph(
             self.bert_config, self.data_reader
         )
diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm_grad.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm_grad.py
index 1c85e6e46d0131..584bfdc7aee71e 100644
--- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm_grad.py
@@ -29,13 +29,19 @@
     "float64": {"rtol": 1e-11, "atol": 1e-11},
 }
 
+TOLERANCE_COMP_GRAD = {
+    "float32": {"rtol": 1e-3, "atol": 1e-3},
+    "float16": {"rtol": 1e-3, "atol": 1e-3},  # amp
+}
+
 
 def generate_data(shape1, shape2, shape3, dtype="float32"):
-    np.random.seed(200)
+    np.random.seed(12)
     np_data1 = np.random.random(shape1).astype(dtype)
     np_data2 = np.random.random(shape2).astype(dtype)
     np_data3 = np.random.random(shape3).astype(dtype)
-    return np_data1, np_data2, np_data3
+    np_data4 = np.ones_like(np_data1).astype(dtype)
+    return np_data1, np_data2, np_data3, np_data4
 
 
 def _reference_layer_norm_naive(
@@ -159,23 +165,33 @@ def fn(x, norm_shape, w, b):
     return F.layer_norm(x, norm_shape, w, b)
 
 
-def expect_backward(x, norm_shape, w, b):
+def dygraph_fused_backward_withNone(x, norm_shape, w, b, y_g):
     paddle.disable_static()
     x.stop_gradient = False
     res = fn(x, norm_shape, w, b)
-    gradients = paddle.grad(res, x)
+    gradients = paddle.grad(res, x, y_g)
     return gradients
 
 
+def dygraph_fused_backward(x, norm_shape, w, b, y_g):
+    paddle.disable_static()
+    x.stop_gradient = False
+    w.stop_gradient = False
+    b.stop_gradient = False
+    res = fn(x, norm_shape, w, b)
+    gradients = paddle.grad(res, [x, w, b], y_g)
+    return gradients[0], gradients[1], gradients[2]
+
+
 class TestCompositelayer_norm(unittest.TestCase):
     def setUp(self):
-        self.dtypes = ["float16", "float32"]
+        self.dtypes = ["float32"]
         self.n_shape = [[4], [64, 128], [64]]
         self.shape1s = [[3, 4], [64, 64, 128], [128, 64, 64]]
         self.shape2s = [[4], [64 * 128], [64]]
         self.shape3s = [[4], [64 * 128], [64]]
 
-    def cal_composite_backward(self, inputs, norm_shape, weight, bias):
+    def static_comp_forward(self, inputs, norm_shape, weight, bias, y_g):
         paddle.enable_static()
         core._set_prim_forward_enabled(True)
         startup_program = paddle.static.Program()
@@ -188,9 +204,16 @@ def cal_composite_backward(self, inputs, norm_shape, weight, bias):
             w = paddle.static.data(
                 'w', shape=weight.shape, dtype=str(weight.dtype)
             )
+            w.stop_gradient = False
             b = paddle.static.data('b', shape=bias.shape, dtype=str(bias.dtype))
+            b.stop_gradient = False
+
             y = fn(x, norm_shape, w, b)
 
+            y_grad = paddle.static.data(
+                'y_grad', shape=y_g.shape, dtype=str(y_g.dtype)
+            )
+
             blocks = main_program.blocks
 
             fwd_ops = [op.type for op in blocks[0].ops]
@@ -203,10 +226,10 @@ def cal_composite_backward(self, inputs, norm_shape, weight, bias):
             # Ensure that layer_norm is splitted into small ops
             self.assertTrue('layer_norm' not in fwd_ops_new)
 
-            z = paddle.static.gradients([y], x)
+            z = paddle.static.gradients([y], [x, w, b], y_grad)
+
             fwd_ops_grad = [op.type for op in blocks[0].ops]
             # Ensure that layer_norm_grad not in grad block
-            self.assertTrue('layer_norm_grad' not in fwd_ops_grad)
 
             exe = paddle.static.Executor()
@@ -217,14 +240,17 @@ def cal_composite_backward(self, inputs, norm_shape, weight, bias):
                     'x': inputs,
                     'w': weight,
                     'b': bias,
+                    'y_grad': y_g,
                 },
-                fetch_list=[z],
+                fetch_list=z,
             )
         paddle.disable_static()
         core._set_prim_forward_enabled(False)
         return res
 
-    def cal2_composite_backward(self, inputs, norm_shape, weight, bias):
+    def static_comp_forward_withNone(
+        self, inputs, norm_shape, weight, bias, y_g
+    ):
         paddle.enable_static()
         core._set_prim_forward_enabled(True)
         startup_program = paddle.static.Program()
@@ -233,7 +259,9 @@ def cal2_composite_backward(self, inputs, norm_shape, weight, bias):
             x = paddle.static.data(
                 'x', shape=inputs.shape, dtype=str(inputs.dtype)
             )
-
+            y_grad = paddle.static.data(
+                'y_grad', shape=y_g.shape, dtype=str(y_g.dtype)
+            )
             x.stop_gradient = False
             y = fn(x, norm_shape, weight, bias)
@@ -249,10 +277,9 @@ def cal2_composite_backward(self, inputs, norm_shape, weight, bias):
             # Ensure that layer_norm is splitted into small ops
             self.assertTrue('layer_norm' not in fwd_ops_new)
 
-            z = paddle.static.gradients([y], x)
+            z = paddle.static.gradients([y], x, y_grad)
             fwd_ops_grad = [op.type for op in blocks[0].ops]
             # Ensure that layer_norm_grad not in grad block
-            self.assertTrue('layer_norm_grad' not in fwd_ops_grad)
 
             exe = paddle.static.Executor()
@@ -261,35 +288,103 @@ def cal2_composite_backward(self, inputs, norm_shape, weight, bias):
                 main_program,
                 feed={
                     'x': inputs,
+                    'y_grad': y_g,
                 },
-                fetch_list=[z],
+                fetch_list=z,
             )
         paddle.disable_static()
         core._set_prim_forward_enabled(False)
         return res
 
-    def compare_backward(self):
-        x, w, b = generate_data(
+    # to_prim after gradient can call comp_layer_norm_grad
+    def static_comp_forward_and_backward(
+        self, inputs, norm_shape, weight, bias, y_g
+    ):
+        paddle.enable_static()
+        core._set_prim_all_enabled(True)
+        startup_program = paddle.static.Program()
+        main_program = paddle.static.Program()
+        with paddle.static.program_guard(main_program, startup_program):
+            x = paddle.static.data(
+                'x', shape=inputs.shape, dtype=str(inputs.dtype)
+            )
+            x.stop_gradient = False
+            w = paddle.static.data(
+                'w', shape=weight.shape, dtype=str(weight.dtype)
+            )
+            w.stop_gradient = False
+            b = paddle.static.data('b', shape=bias.shape, dtype=str(bias.dtype))
+            b.stop_gradient = False
+
+            y_grad = paddle.static.data(
+                'y_grad', shape=y_g.shape, dtype=str(y_g.dtype)
+            )
+
+            y = fn(x, norm_shape, w, b)
+
+            blocks = main_program.blocks
+
+            fwd_ops = [op.type for op in blocks[0].ops]
+            # Ensure that layer_norm is in the original block
+            self.assertTrue('layer_norm' in fwd_ops)
+
+            z = paddle.static.gradients([y], [x, w, b], y_grad)
+
+            primapi.to_prim(blocks)
+
+            exe = paddle.static.Executor()
+            exe.run(startup_program)
+            res = exe.run(
+                main_program,
+                feed={
+                    'x': inputs,
+                    'w': weight,
+                    'b': bias,
+                    'y_grad': y_g,
+                },
+                fetch_list=z,
+            )
+        paddle.disable_static()
+        core._set_prim_all_enabled(False)
+        return res
+
+    def compare_comp_forward(self):
+        x, w, b, y_g = generate_data(
             attrs.shape1, attrs.shape2, attrs.shape3, attrs.dtype
         )
         n_shape = attrs.n_shape
         x_p = paddle.to_tensor(x)
         w_p = paddle.to_tensor(w)
         b_p = paddle.to_tensor(b)
+        y_g_p = paddle.to_tensor(y_g)
 
-        expect = expect_backward(x_p, n_shape, w_p, b_p)[0].numpy()
-        actual = self.cal_composite_backward(x, n_shape, w, b)[0]
+        expect = dygraph_fused_backward(x_p, n_shape, w_p, b_p, y_g_p)
+        actual_fwd = self.static_comp_forward(x, n_shape, w, b, y_g)
+        actual_all = self.static_comp_forward_and_backward(
+            x, n_shape, w, b, y_g
+        )
 
-        assert expect.dtype == actual.dtype
+        assert expect[0].numpy().dtype == actual_fwd[0].dtype
         np.testing.assert_allclose(
-            expect,
-            actual,
+            expect[0].numpy(),
+            actual_fwd[0],
             rtol=attrs.get_rtol("backward"),
             atol=attrs.get_atol("backward"),
         )
 
-        expect_2 = expect_backward(x_p, n_shape, None, None)[0].numpy()
-        actual_2 = self.cal2_composite_backward(x, n_shape, None, None)[0]
+        np.testing.assert_allclose(
+            actual_fwd[0],
+            actual_all[0],
+            rtol=TOLERANCE_COMP_GRAD[attrs.dtype]['rtol'],
+            atol=TOLERANCE_COMP_GRAD[attrs.dtype]['atol'],
+        )
+
+        expect_2 = dygraph_fused_backward_withNone(
+            x_p, n_shape, None, None, y_g_p
+        )[0].numpy()
+        actual_2 = self.static_comp_forward_withNone(
+            x, n_shape, None, None
+        )[0]
         assert expect_2.dtype == actual_2.dtype
         np.testing.assert_allclose(
             expect_2,
@@ -311,23 +406,23 @@ def test_backward(self):
                 self.shape2s[t],
                 self.shape3s[t],
             )
-            self.compare_backward()
+            self.compare_comp_forward()
 
 
 class TestCompositelayer_normPrimBackward(unittest.TestCase):
     def setUp(self):
         core._set_prim_backward_enabled(True)
-        self.dtypes = ["float16", "float32"]
+        self.dtypes = ["float32"]
         self.n_shape = [[4], [64, 128], [64]]
         self.shape1s = [[3, 4], [64, 64, 128], [128, 64, 64]]
         self.shape2s = [[4], [64 * 128], [64]]
         self.shape3s = [[4], [64 * 128], [64]]
 
-    def cal_composite_backward(self, inputs, norm_shape, weight, bias):
+    def static_comp_forward_and_backward(
+        self, inputs, norm_shape, weight, bias
+    ):
         paddle.enable_static()
         core._set_prim_all_enabled(True)
-        core._add_skip_comp_ops("sqrt")
-        # TODO(Ruting) delete this after modify sqrt
         startup_program = paddle.static.Program()
         main_program = paddle.static.Program()
         with paddle.static.program_guard(main_program, startup_program):
@@ -392,16 +487,19 @@ def cal2_composite_backward(self, inputs, norm_shape, weight, bias):
         return res
 
     def compare_backward(self):
-        x, w, b = generate_data(
+        x, w, b, y_g = generate_data(
             attrs.shape1, attrs.shape2, attrs.shape3, attrs.dtype
         )
         n_shape = attrs.n_shape
         x_p = paddle.to_tensor(x)
         w_p = paddle.to_tensor(w)
         b_p = paddle.to_tensor(b)
+        y_g_p = paddle.to_tensor(y_g)
 
-        expect = expect_backward(x_p, n_shape, w_p, b_p)[0].numpy()
-        actual = self.cal_composite_backward(x, n_shape, w, b)[0]
+        expect = dygraph_fused_backward_withNone(x_p, n_shape, w_p, b_p, y_g_p)[
+            0
+        ].numpy()
+        actual = self.static_comp_forward_and_backward(x, n_shape, w, b)[0]
 
         assert expect.dtype == actual.dtype
         np.testing.assert_allclose(
@@ -411,8 +509,12 @@ def compare_backward(self):
             atol=attrs.get_rtol("prim_backward"),
         )
 
-        expect_2 = expect_backward(x_p, n_shape, None, None)[0].numpy()
-        actual_2 = self.cal2_composite_backward(x, n_shape, None, None)[0]
+        expect_2 = dygraph_fused_backward_withNone(
+            x_p, n_shape, None, None, y_g_p
+        )[0].numpy()
+        actual_2 = self.static_comp_forward_and_backward_withNone(
+            x, n_shape, None, None
+        )[0]
         assert expect_2.dtype == actual_2.dtype
         np.testing.assert_allclose(
             expect_2,
@@ -457,7 +559,7 @@ def setUp(self):
             [64 * 128],
         ]
 
-    def cal_composite_backward(self, inputs, norm_shape, weight, bias, y_grad):
+    def static_comp_forward(self, inputs, norm_shape, weight, bias, y_grad):
         paddle.enable_static()
         core._set_prim_forward_enabled(True)
         startup_program = paddle.static.Program()
@@ -509,13 +611,11 @@ def cal_composite_backward(self, inputs, norm_shape, weight, bias, y_grad):
         core._set_prim_forward_enabled(False)
         return res[0], res[1]
 
-    def cal_composite_backward_prim(
+    def static_comp_forward_prim(
         self, inputs, norm_shape, weight, bias, y_grad
     ):
         paddle.enable_static()
         core._set_prim_all_enabled(True)
-        core._add_skip_comp_ops("sqrt")
-        # TODO(Ruting) delete this after modify sqrt
         startup_program = paddle.static.Program()
         main_program = paddle.static.Program()
         with paddle.static.program_guard(main_program, startup_program):
@@ -548,16 +648,16 @@ def cal_composite_backward_prim(
         return res[0], res[1]
 
     def compare_backward(self):
-        x, w, b = generate_data(
+        x, w, b, y_grad = generate_data(
             attrs.shape1, attrs.shape2, attrs.shape3, attrs.dtype
         )
-        y_grad = np.ones_like(x)
+
         n_shape = attrs.n_shape
 
-        composite1, composite2 = self.cal_composite_backward(
+        composite1, composite2 = self.static_comp_forward(
             x, n_shape, w, b, y_grad
         )
-        composite_p1, composite_p2 = self.cal_composite_backward_prim(
+        composite_p1, composite_p2 = self.static_comp_forward_prim(
             x, n_shape, w, b, y_grad
         )
diff --git a/python/paddle/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py b/python/paddle/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py
index 8bd89f48337a93..e455f9f11fb286 100644
--- a/python/paddle/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py
+++ b/python/paddle/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py
@@ -120,16 +120,17 @@ def test_prim(self):
         np.testing.assert_allclose(self.dy2st, dy2st_prim, rtol=1e-1)
 
     @unittest.skipIf(
-        not paddle.is_compiled_with_cinn(), "padle is not compiled with CINN"
+        not paddle.is_compiled_with_cinn(), "paddle is not compiled with CINN"
     )
     def test_cinn(self):
         dy2st_cinn = train(to_static=True, enable_prim=False, enable_cinn=True)
         np.testing.assert_allclose(self.dy2st, dy2st_cinn, rtol=1e-6)
 
     @unittest.skipIf(
-        not paddle.is_compiled_with_cinn(), "padle is not compiled with CINN"
+        not paddle.is_compiled_with_cinn(), "paddle is not compiled with CINN"
     )
     def test_prim_cinn(self):
+        core._add_skip_comp_ops("layer_norm")
         dy2st_prim_cinn = train(
             to_static=True, enable_prim=True, enable_cinn=True
         )
diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py
index 02a88b155bba81..06c447d2cd627a 100644
--- a/python/paddle/incubate/autograd/composite_rules.py
+++ b/python/paddle/incubate/autograd/composite_rules.py
@@ -168,6 +168,7 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
     variance = reshape(variance, [-1])
     if is_amp:
         out = cast(out, "float16")
+
     return out, mean_, variance
 
 
@@ -301,6 +302,8 @@ def stack_composite(x, axis):
 def flatten_contiguous_range_composite(x, start_axis, stop_axis):
     """
     define composite rule of op flatten, flatten_contiguous_range -> flatten.
+
+    xshape is x's shape with a leading 0 prepended; it keeps the shape information of x for computing the gradient. CINN doesn't need xshape for the backward pass, so None is returned instead of xshape.
     shape_out is the parameter of reshape, get from start_axis and stop_axis.
     out = reshape(x, shape=shape_out), xshape