diff --git a/paddle/fluid/framework/new_executor/feed_fetch_utils.cc b/paddle/fluid/framework/new_executor/feed_fetch_utils.cc
index 99829de387c321..f82350ec6d103f 100644
--- a/paddle/fluid/framework/new_executor/feed_fetch_utils.cc
+++ b/paddle/fluid/framework/new_executor/feed_fetch_utils.cc
@@ -115,6 +115,7 @@ void FetchTensors(const std::vector<std::string>& job_fetch_names,
         &(PADDLE_GET(phi::DenseTensor, fetch_list->at(micro_batch_id)[col]));
     if (src.IsInitialized()) {
       TensorCopy(src, platform::CPUPlace(), dst);
+      dst->set_lod(src.lod());
     } else {
       VLOG(6) << "Found " << var_name
               << " is not initialized and skip TensorCopy.";
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 423638426f7fd8..5d03c833a87c7e 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -106,10 +106,6 @@ if (WITH_GPU OR WITH_ROCM)
   register_cu_kernel(class_center_sample_op SRCS class_center_sample_op.cu DEPS ${OP_HEADER_DEPS})
 endif()
 
-if (WITH_MKLDNN)
-  register_mkldnn_kernel(layer_norm_op SRCS layer_norm_mkldnn_op.cc DEPS ${OP_HEADER_DEPS})
-endif()
-
 if (WITH_GPU OR WITH_ROCM)
   op_library(activation_op SRCS activation_op.cc activation_op.kps soft_relu_op.cu DEPS ${OP_HEADER_DEPS})
 elseif (WITH_XPU_KP)
diff --git a/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc
deleted file mode 100644
index d2b715a5f56e6a..00000000000000
--- a/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc
+++ /dev/null
@@ -1,150 +0,0 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/framework/op_registry.h"
-
-#include "paddle/phi/backends/onednn/onednn_reuse.h"
-#include "paddle/phi/common/data_type.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-class LayerNormOneDNNHandler
-    : public phi::funcs::
-          OneDNNHandlerNoCachingT<T, dnnl::layer_normalization_forward> {
- public:
-  LayerNormOneDNNHandler(const std::vector<int64_t>& dims,
-                         const float& epsilon,
-                         const dnnl::normalization_flags& flags,
-                         const bool& is_test,
-                         const phi::DenseTensor* x,
-                         const dnnl::engine engine,
-                         platform::Place cpu_place)
-      : phi::funcs::OneDNNHandlerNoCachingT<T, dnnl::layer_normalization_forward>(
-            engine, cpu_place) {
-    const auto fwd_prop_kind = is_test ? dnnl::prop_kind::forward_inference
-                                       : dnnl::prop_kind::forward_training;
-
-    this->AcquireForwardPrimitiveDescriptor(
-        fwd_prop_kind, x->mem_desc(), x->mem_desc(), epsilon, flags);
-  }
-
-  std::tuple<std::shared_ptr<dnnl::memory>, std::shared_ptr<dnnl::memory>>
-  AcquireScaleShiftMemory(const phi::DenseTensor* scale,
-                          const phi::DenseTensor* shift) {
-    auto scale_memory = this->AcquireMemoryFromPrimitive(
-        this->fwd_pd_->weights_desc(),
-        phi::funcs::to_void_cast<float>(scale->data<float>()));
-    auto shift_memory = this->AcquireMemoryFromPrimitive(
-        this->fwd_pd_->weights_desc(),
-        phi::funcs::to_void_cast<float>(shift->data<float>()));
-
-    return std::make_tuple(scale_memory, shift_memory);
-  }
-
-  std::shared_ptr<dnnl::memory> AcquireMeanMemory(phi::DenseTensor* mean) {
-    float* mean_data = mean->mutable_data<float>(
-        this->place_, this->fwd_pd_->mean_desc().get_size());
-    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
-                                            mean_data);
-  }
-
-  std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
-      phi::DenseTensor* variance) {
-    float* variance_data = variance->mutable_data<float>(
-        this->place_, this->fwd_pd_->variance_desc().get_size());
-    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
-                                            variance_data);
-  }
-};
-
-template <typename T>
-class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* x = ctx.Input<phi::DenseTensor>("X");
-    auto* out = ctx.Output<phi::DenseTensor>("Y");
-    auto* scale = ctx.Input<phi::DenseTensor>("Scale");
-    auto* bias = ctx.Input<phi::DenseTensor>("Bias");
-
-    const float epsilon = ctx.Attr<float>("epsilon");
-    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
-    const bool is_test = ctx.Attr<bool>("is_test");
-
-    auto& dev_ctx = ctx.template device_context<phi::OneDNNContext>();
-    const auto& onednn_engine = dev_ctx.GetEngine();
-
-    auto src_tz = common::vectorize(x->dims());
-    PADDLE_ENFORCE_EQ(begin_norm_axis,
-                      (src_tz.size() - 1),
-                      platform::errors::InvalidArgument(
-                          "MKL-DNN Layer Norm supports only last logical "
-                          "axis:%d as begin_norm_axis.",
-                          (src_tz.size() - 1)));
-
-    const bool with_scaleshift = (scale && bias);
-    dnnl::normalization_flags flags{};
-
-    if (with_scaleshift) {
-      flags |= dnnl::normalization_flags::use_scale |
-               dnnl::normalization_flags::use_shift;
-    }
-
-    LayerNormOneDNNHandler<T> handler(
-        src_tz, epsilon, flags, is_test, x, onednn_engine, ctx.GetPlace());
-
-    auto src_memory = handler.AcquireSrcMemory(x);
-    auto dst_memory = handler.AcquireDstMemory(out);
-
-    auto layer_norm_p = handler.AcquireForwardPrimitive();
-
-    auto& astream = phi::OneDNNContext::tls().get_stream();
-    std::unordered_map<int, dnnl::memory> args = {{DNNL_ARG_SRC, *src_memory},
-                                                  {DNNL_ARG_DST, *dst_memory}};
-
-    if (!is_test) {
-      auto* mean = ctx.Output<phi::DenseTensor>("Mean");
-      auto* var = ctx.Output<phi::DenseTensor>("Variance");
-
-      auto mean_memory = handler.AcquireMeanMemory(mean);
-      auto variance_memory = handler.AcquireVarianceMemory(var);
-
-      args.insert({DNNL_ARG_MEAN, *mean_memory});
-      args.insert({DNNL_ARG_VARIANCE, *variance_memory});
-    }
-
-    if (with_scaleshift) {
-      auto scaleshift_mems = handler.AcquireScaleShiftMemory(scale, bias);
-      args.insert({DNNL_ARG_SCALE, *(std::get<0>(scaleshift_mems))});
-      args.insert({DNNL_ARG_SHIFT, *(std::get<1>(scaleshift_mems))});
-    }
-
-    layer_norm_p->execute(astream, args);
-    astream.wait();
-
-    out->set_mem_desc(dst_memory->get_desc());
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OP_KERNEL(layer_norm,
-                   MKLDNN,
-                   ::phi::CPUPlace,
-                   ops::LayerNormMKLDNNOpKernel<float>,
-                   ops::LayerNormMKLDNNOpKernel<paddle::platform::bfloat16>);
diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake
index 2e1b6f86d6370c..07136f7bd4f310 100644
--- a/paddle/fluid/operators/unity_build_rule.cmake
+++ b/paddle/fluid/operators/unity_build_rule.cmake
@@ -131,8 +131,6 @@ register_unity_group(
   l1_norm_op.cc
   label_smooth_op.cc
   generated_op
-  mkldnn/layer_norm_mkldnn_op.cc
-  mkldnn/layer_norm_mkldnn_op.cc
   linspace_op.cc
   load_combine_op.cc
   load_op.cc)
diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
index d9d8fe69990249..4194cdb1366e05 100644
--- a/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
+++ b/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
@@ -114,7 +114,8 @@
 
 # - op : fused_transpose
 
-# - op : fusion_gru
+- op : fusion_gru
+  extra_args : str mkldnn_data_type="float32", float scale_data=1.0, float shift_data=0.0, float[] scale_weights={1.0f}
 
 # - op : fusion_lstm
 
@@ -130,7 +131,8 @@
 
 - op : hardswish_grad
 
-# - op : layer_norm
+- op : layer_norm
+  extra_args : str mkldnn_data_type="float32", bool is_test=false
 
 - op : leaky_relu
 
@@ -146,13 +148,13 @@
   extra_args : bool is_test=false
   data_format_tensors : x, out, mid_out, out_grad
 
-# - op : matmul
-#   extra_args : str mkldnn_data_type="float32"
-#   layout_transform :
-#     arg_name: cur_paddle_data_layout
-#     tensors: x, y
+- op : matmul
+  extra_args : str mkldnn_data_type="float32"
+  data_format_tensors : x, y
 
-# - op : matmul_grad
+- op : matmul_grad
+  extra_args : str mkldnn_data_type="float32"
+  data_format_tensors : x, y, out_grad
 
 # - op : matmul_with_flatten
 
diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
index f5f94f204d2f33..d5e003cf304265 100644
--- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
+++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
@@ -50,7 +50,7 @@
 #include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"
 
 #ifdef PADDLE_WITH_DNNL
-#include "build/paddle/fluid/framework/framework.pb.h"
+#include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
 #include "paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h"
 #include "paddle/fluid/pir/dialect/operator/trait/onednn.h"
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 490c43ace3c2c2..ae6a3106ddb26f 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -1433,6 +1433,8 @@
     batched_input : BatchedInput
     batched_out : BatchedOut
     hidden : Hidden
+  attrs :
+    {scale_data : Scale_data, shift_data : Shift_data, scale_weights : Scale_weights}
   extra :
     attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", float Scale_data = 1.0f, float Shift_data = 0.0f, 'float[] Scale_weights = {1.0f}']
 
diff --git a/paddle/phi/kernels/cpu/onednn_to_paddle_layout_kernel.cc b/paddle/phi/kernels/cpu/onednn_to_paddle_layout_kernel.cc
index f9324ca8b3e5f6..f9257ebcddc36f 100644
--- a/paddle/phi/kernels/cpu/onednn_to_paddle_layout_kernel.cc
+++ b/paddle/phi/kernels/cpu/onednn_to_paddle_layout_kernel.cc
@@ -64,6 +64,11 @@ void OneDNN2PaddleLayout(const Context& dev_ctx,
   VLOG(4) << "src_layout: " << src_layout << ", tmp_layout: " << tmp_layout;
 
   if (src_layout != DataLayout::ONEDNN || !x.storage_properties_initialized()) {
+    if (!x.IsInitialized()) {
+      out->Resize(x.dims());
+      out->set_layout(tmp_layout);
+      return;
+    }
     out->ShareDataWith(x);
     out->ShareInplaceVersionCounterWith(x);
     out->set_layout(static_cast<DataLayout>(tmp_layout));
diff --git a/paddle/phi/kernels/onednn/layer_norm_kernel.cc b/paddle/phi/kernels/onednn/layer_norm_kernel.cc
new file mode 100644
index 00000000000000..02aa5298b23261
--- /dev/null
+++ b/paddle/phi/kernels/onednn/layer_norm_kernel.cc
@@ -0,0 +1,148 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/layer_norm_kernel.h"
+
+#include "paddle/phi/backends/onednn/onednn_reuse.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T>
+class LayerNormOneDNNHandler
+    : public phi::funcs::
+          OneDNNHandlerNoCachingT<T, dnnl::layer_normalization_forward> {
+ public:
+  LayerNormOneDNNHandler(const std::vector<int64_t>& dims,
+                         const float& epsilon,
+                         const dnnl::normalization_flags& flags,
+                         const bool& is_test,
+                         const phi::DenseTensor* x,
+                         const dnnl::engine engine,
+                         Place cpu_place)
+      : phi::funcs::OneDNNHandlerNoCachingT<T, dnnl::layer_normalization_forward>(
+            engine, cpu_place) {
+    const auto fwd_prop_kind = is_test ? dnnl::prop_kind::forward_inference
+                                       : dnnl::prop_kind::forward_training;
+
+    this->AcquireForwardPrimitiveDescriptor(
+        fwd_prop_kind, x->mem_desc(), x->mem_desc(), epsilon, flags);
+  }
+
+  std::tuple<std::shared_ptr<dnnl::memory>, std::shared_ptr<dnnl::memory>>
+  AcquireScaleShiftMemory(const phi::DenseTensor* scale,
+                          const phi::DenseTensor* shift) {
+    auto scale_memory = this->AcquireMemoryFromPrimitive(
+        this->fwd_pd_->weights_desc(),
+        phi::funcs::to_void_cast<float>(scale->data<float>()));
+    auto shift_memory = this->AcquireMemoryFromPrimitive(
+        this->fwd_pd_->weights_desc(),
+        phi::funcs::to_void_cast<float>(shift->data<float>()));
+
+    return std::make_tuple(scale_memory, shift_memory);
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireMeanMemory(const OneDNNContext& dev_ctx,
+                                                  phi::DenseTensor* mean) {
+    float* mean_data = dev_ctx.template Alloc<float>(
+        mean, this->fwd_pd_->mean_desc().get_size());
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
+                                            mean_data);
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
+      const OneDNNContext& dev_ctx, phi::DenseTensor* variance) {
+    float* variance_data = dev_ctx.template Alloc<float>(
+        variance, this->fwd_pd_->variance_desc().get_size());
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
+                                            variance_data);
+  }
+};
+
+template <typename T, typename Context>
+void LayerNormKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     const paddle::optional<DenseTensor>& scale_opt,
+                     const paddle::optional<DenseTensor>& bias_opt,
+                     float epsilon,
+                     int begin_norm_axis,
+                     DenseTensor* y,
+                     DenseTensor* mean,
+                     DenseTensor* var) {
+  bool is_test = dev_ctx.HasDnnAttr("is_test")
+                     ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("is_test"))
+                     : false;
+
+  const auto& onednn_engine = dev_ctx.GetEngine();
+
+  auto src_tz = common::vectorize(x.dims());
+  PADDLE_ENFORCE_EQ(begin_norm_axis,
+                    (src_tz.size() - 1),
+                    phi::errors::InvalidArgument(
+                        "MKL-DNN Layer Norm supports only last logical "
+                        "axis:%d as begin_norm_axis.",
+                        (src_tz.size() - 1)));
+
+  const bool with_scaleshift = (scale_opt && bias_opt);
+  dnnl::normalization_flags flags{};
+
+  if (with_scaleshift) {
+    flags |= dnnl::normalization_flags::use_scale |
+             dnnl::normalization_flags::use_shift;
+  }
+
+  LayerNormOneDNNHandler<T> handler(
+      src_tz, epsilon, flags, is_test, &x, onednn_engine, dev_ctx.GetPlace());
+
+  auto src_memory = handler.AcquireSrcMemory(&x);
+  auto dst_memory = handler.AcquireDstMemory(y);
+
+  auto layer_norm_p = handler.AcquireForwardPrimitive();
+
+  auto& astream = phi::OneDNNContext::tls().get_stream();
+  std::unordered_map<int, dnnl::memory> args = {{DNNL_ARG_SRC, *src_memory},
+                                                {DNNL_ARG_DST, *dst_memory}};
+
+  if (!is_test) {
+    auto mean_memory = handler.AcquireMeanMemory(dev_ctx, mean);
+    auto variance_memory = handler.AcquireVarianceMemory(dev_ctx, var);
+
+    args.insert({DNNL_ARG_MEAN, *mean_memory});
+    args.insert({DNNL_ARG_VARIANCE, *variance_memory});
+  }
+
+  if (with_scaleshift) {
+    auto scaleshift_mems = handler.AcquireScaleShiftMemory(scale_opt.get_ptr(),
+                                                           bias_opt.get_ptr());
+    args.insert({DNNL_ARG_SCALE, *(std::get<0>(scaleshift_mems))});
+    args.insert({DNNL_ARG_SHIFT, *(std::get<1>(scaleshift_mems))});
+  }
+
+  layer_norm_p->execute(astream, args);
+  astream.wait();
+
+  y->set_mem_desc(dst_memory->get_desc());
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(layer_norm,
+                   OneDNN,
+                   ONEDNN,
+                   phi::LayerNormKernel,
+                   float,
+                   phi::dtype::bfloat16) {
+  kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
+  kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
+}
diff --git a/test/legacy_test/test_fusion_gru_op.py b/test/legacy_test/test_fusion_gru_op.py
index a86fd9b1f7b7ce..f36a1fd4a72cb9 100644
--- a/test/legacy_test/test_fusion_gru_op.py
+++ b/test/legacy_test/test_fusion_gru_op.py
@@ -114,7 +114,9 @@ def setUp(self):
     def test_check_output(self):
         for use_seq in {True, False}:
             self.attrs['use_seq'] = use_seq
-            self.check_output(check_dygraph=False)
+            self.check_output(
+                check_dygraph=False, check_pir_onednn=self.check_pir_onednn
+            )
 
 
 class TestFusionGRUOpNoInitial(TestFusionGRUOp):
diff --git a/test/mkldnn/test_fusion_gru_bf16_mkldnn_op.py b/test/mkldnn/test_fusion_gru_bf16_mkldnn_op.py
index 93b141f9eefca5..ae44798dce4eb3 100644
--- a/test/mkldnn/test_fusion_gru_bf16_mkldnn_op.py
+++ b/test/mkldnn/test_fusion_gru_bf16_mkldnn_op.py
@@ -32,7 +32,9 @@ def set_confs(self):
     def test_check_output(self):
         for use_seq in {True, False}:
             self.attrs['use_seq'] = use_seq
-            self.check_output(check_dygraph=False)
+            self.check_output(
+                check_dygraph=False, check_pir_onednn=self.check_pir_onednn
+            )
 
     def setUp(self):
         self.op_type = "fusion_gru"
diff --git a/test/mkldnn/test_fusion_gru_int8_mkldnn_op.py b/test/mkldnn/test_fusion_gru_int8_mkldnn_op.py
index 352ce5bc5db5fb..b7b775de2581e3 100644
--- a/test/mkldnn/test_fusion_gru_int8_mkldnn_op.py
+++ b/test/mkldnn/test_fusion_gru_int8_mkldnn_op.py
@@ -150,7 +150,11 @@ def setUp(self):
         }
 
     def test_check_output(self):
-        self.check_output(check_dygraph=False, atol=self.error_margin)
+        self.check_output(
+            check_dygraph=False,
+            atol=self.error_margin,
+            check_pir_onednn=self.check_pir_onednn,
+        )
 
 
 class TestFusionGRUINT8MKLDNNOp2(TestFusionGRUINT8MKLDNNOp):
diff --git a/test/mkldnn/test_fusion_gru_mkldnn_op.py b/test/mkldnn/test_fusion_gru_mkldnn_op.py
index 9e619b73ff1793..112c7c1389dc02 100644
--- a/test/mkldnn/test_fusion_gru_mkldnn_op.py
+++ b/test/mkldnn/test_fusion_gru_mkldnn_op.py
@@ -20,30 +20,35 @@
 class TestFusionGRUMKLDNNOp(TestFusionGRUOp):
     def set_confs(self):
         self.use_mkldnn = True
+        self.check_pir_onednn = True
 
 
 class TestFusionGRUMKLDNNOpNoInitial(TestFusionGRUOp):
     def set_confs(self):
         self.with_h0 = False
         self.use_mkldnn = True
+        self.check_pir_onednn = True
 
 
 class TestFusionGRUMKLDNNOpNoBias(TestFusionGRUOp):
     def set_confs(self):
         self.with_bias = False
         self.use_mkldnn = True
+        self.check_pir_onednn = True
 
 
 class TestFusionGRUMKLDNNOpReverse(TestFusionGRUOp):
     def set_confs(self):
         self.is_reverse = True
         self.use_mkldnn = True
+        self.check_pir_onednn = True
 
 
 class TestFusionGRUMKLDNNOpOriginMode(TestFusionGRUOp):
     def set_confs(self):
         self.origin_mode = True
         self.use_mkldnn = True
+        self.check_pir_onednn = True
 
 
 class TestFusionGRUMKLDNNOpMD1(TestFusionGRUOp):
@@ -51,6 +56,7 @@ def set_confs(self):
         self.M = 36
         self.D = 8
         self.use_mkldnn = True
+        self.check_pir_onednn = True
 
 
 class TestFusionGRUMKLDNNOpMD2(TestFusionGRUOp):
@@ -58,6 +64,7 @@ def set_confs(self):
         self.M = 8
         self.D = 8
         self.use_mkldnn = True
+        self.check_pir_onednn = True
 
 
 class TestFusionGRUMKLDNNOpMD3(TestFusionGRUOp):
@@ -65,6 +72,7 @@ def set_confs(self):
         self.M = 17
         self.D = 15
         self.use_mkldnn = True
+        self.check_pir_onednn = True
 
 
 class TestFusionGRUMKLDNNOpBS1(TestFusionGRUOp):
@@ -72,6 +80,7 @@ def set_confs(self):
         self.lod = [[3]]
         self.D = 16
         self.use_mkldnn = True
+        self.check_pir_onednn = True
 
 
 if __name__ == "__main__":
diff --git a/test/mkldnn/test_layer_norm_bf16_mkldnn_op.py b/test/mkldnn/test_layer_norm_bf16_mkldnn_op.py
index a67dd64a4fbd4f..96dd1f818b1239 100644
--- a/test/mkldnn/test_layer_norm_bf16_mkldnn_op.py
+++ b/test/mkldnn/test_layer_norm_bf16_mkldnn_op.py
@@ -23,6 +23,7 @@
     TestLayerNormMKLDNNOp,
     _reference_layer_norm_naive,
 )
+from utils import pir_executor_guard
 
 from paddle import base, enable_static
 from paddle.base import core
@@ -133,9 +134,10 @@ def check_forward(
         self.__assert_close(variance, out[2], "variance", 1e-3)
 
     def test_check_forward_with_is_test(self):
-        self.check_forward(
-            shape=[2, 3, 4, 5], begin_norm_axis=3, with_is_test=True
-        )
+        with pir_executor_guard():
+            self.check_forward(
+                shape=[2, 3, 4, 5], begin_norm_axis=3, with_is_test=True
+            )
 
     # TODO (jczaja): Enable those to test when enabling training using bf16
     def test_check_forward_with_scale_and_bias(self):
diff --git a/test/mkldnn/test_layer_norm_mkldnn_op.py b/test/mkldnn/test_layer_norm_mkldnn_op.py
index c225469e71cc80..d2ba6062ffe6bb 100644
--- a/test/mkldnn/test_layer_norm_mkldnn_op.py
+++ b/test/mkldnn/test_layer_norm_mkldnn_op.py
@@ -19,6 +19,7 @@
 import numpy as np
 from op_test import OpTestTool, _set_use_system_allocator
+from utils import pir_executor_guard
 
 from paddle import base, enable_static
 from paddle.base import core
 
@@ -144,20 +145,24 @@ def check_forward(
 
     @OpTestTool.skip_if_not_cpu_bf16()
     def test_check_forward_non_last_begin_norm_axis(self):
-        self.check_forward(shape=[2, 3, 4, 5], begin_norm_axis=2)
+        with pir_executor_guard():
+            self.check_forward(shape=[2, 3, 4, 5], begin_norm_axis=2)
 
     def test_check_forward_with_scale_and_bias(self):
-        self.check_forward(shape=[2, 3, 4, 5], begin_norm_axis=3)
+        with pir_executor_guard():
+            self.check_forward(shape=[2, 3, 4, 5], begin_norm_axis=3)
 
     def test_check_forward_without_scale_and_bias(self):
-        self.check_forward(
-            shape=[2, 3, 4, 5], begin_norm_axis=3, with_scale_bias=False
-        )
+        with pir_executor_guard():
+            self.check_forward(
+                shape=[2, 3, 4, 5], begin_norm_axis=3, with_scale_bias=False
+            )
 
     def test_check_forward_with_is_test(self):
-        self.check_forward(
-            shape=[2, 3, 4, 5], begin_norm_axis=3, with_is_test=True
-        )
+        with pir_executor_guard():
+            self.check_forward(
+                shape=[2, 3, 4, 5], begin_norm_axis=3, with_is_test=True
+            )
 
 
 if __name__ == "__main__":
diff --git a/test/mkldnn/test_matmul_v2_mkldnn_op.py b/test/mkldnn/test_matmul_v2_mkldnn_op.py
index 8c9fb2e0928354..42c592cca9bdf0 100644
--- a/test/mkldnn/test_matmul_v2_mkldnn_op.py
+++ b/test/mkldnn/test_matmul_v2_mkldnn_op.py
@@ -83,10 +83,12 @@ def setUp(self):
         self.outputs = {'Out': result}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir_onednn=True, check_dygraph=False)
 
     def test_check_grad(self):
-        self.check_grad(['X', 'Y'], 'Out')
+        self.check_grad(
+            ['X', 'Y'], 'Out', check_pir_onednn=True, check_dygraph=False
+        )
 
 
 class TestMatMulV2VectorXMatrixTransposeYOneDNNOp(
@@ -313,7 +315,9 @@ def set_dtype_attr(self):
         self.attrs['mkldnn_data_type'] = "bfloat16"
 
     def test_check_output(self):
-        self.check_output_with_place(core.CPUPlace())
+        self.check_output_with_place(
+            core.CPUPlace(), check_pir_onednn=True, check_dygraph=False
+        )
 
     def test_check_grad(self):
         self.calculate_grads()
@@ -323,6 +327,8 @@ def test_check_grad(self):
             "Out",
             user_defined_grads=[self.dx, self.dy],
             user_defined_grad_outputs=[convert_float_to_uint16(self.dout)],
+            check_pir_onednn=True,
+            check_dygraph=False,
 
     def matmul_grad(self, x, transpose_x, y, transpose_y):
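
Note (not part of the diff): a minimal Python sketch of the CPU path the migrated kernel serves. With FLAGS_use_mkldnn enabled in static-graph mode, a layer_norm whose begin_norm_axis is the last axis should be handled by the oneDNN kernel registered above; the shape, flag usage, and variable names below are illustrative assumptions, not taken from this change.

    import numpy as np
    import paddle

    paddle.enable_static()
    # Route eligible CPU ops to oneDNN (illustrative; not required by the diff).
    paddle.set_flags({'FLAGS_use_mkldnn': True})

    main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name='x', shape=[2, 3, 4, 5], dtype='float32')
        # oneDNN layer_norm only supports normalizing over the last axis.
        y = paddle.static.nn.layer_norm(x, begin_norm_axis=3, epsilon=1e-5)

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup_prog)
    (out,) = exe.run(main_prog,
                     feed={'x': np.random.random([2, 3, 4, 5]).astype('float32')},
                     fetch_list=[y])
    print(out.shape)  # (2, 3, 4, 5)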