diff --git a/paddle/phi/kernels/reduce_kernel_impl.cc b/paddle/phi/kernels/reduce_kernel_impl.cc
index 000cb99034c26a..93192480999039 100644
--- a/paddle/phi/kernels/reduce_kernel_impl.cc
+++ b/paddle/phi/kernels/reduce_kernel_impl.cc
@@ -20,10 +20,16 @@ namespace phi {
 // oneDNN's reduction kernel is optimized only for reducing throughout the
 // most outer dims, so in case of another type of reduction, it would be
 // better to fallback to native implementation
-inline bool HasOptimizedOneDNNKernel(const KernelContext* ctx) {
+inline bool HasOptimizedOneDNNKernel(const KernelContext* ctx,
+                                     const bool mean_op) {  // true for mean
   const DenseTensor& x = ctx->InputAt(0);
-  const TensorRef& dims_tmp = ctx->AttrAt(0);
-  IntArray dims_array = IntArray(*dims_tmp.Get());
+  IntArray dims_array;  // reduction dims, filled from attribute 0 below
+  if (mean_op) {  // attribute 0 is assigned to the IntArray directly
+    dims_array = ctx->AttrAt(0);  // NOTE(review): AttrAt type arg missing?
+  } else {  // attribute 0 is a TensorRef; build the IntArray from *Get()
+    const TensorRef& dims_tmp = ctx->AttrAt(0);
+    dims_array = IntArray(*dims_tmp.Get());
+  }
   int ndims = x.dims().size();
   const bool reduce_all = recompute_reduce_all(x, dims_array);
   auto dims = dims_array.GetData();
@@ -53,7 +59,15 @@ inline bool HasOptimizedOneDNNKernel(const KernelContext* ctx) {
 
 bool ReduceCheckIfOneDNNSupport(const KernelContext* ctx) {
   if (ctx->InputAt(0).dims().size() > 5 ||
-      !HasOptimizedOneDNNKernel(ctx)) {
+      !HasOptimizedOneDNNKernel(ctx, false)) {  // mean_op = false
+    return false;
+  }
+  return true;
+}
+
+bool ReduceMeanCheckIfOneDNNSupport(const KernelContext* ctx) {
+  if (ctx->InputAt(0).dims().size() > 5 ||
+      !HasOptimizedOneDNNKernel(ctx, true)) {  // mean_op = true
     return false;
   }
   return true;
diff --git a/paddle/phi/kernels/reduce_kernel_impl.h b/paddle/phi/kernels/reduce_kernel_impl.h
index aef4f57ddbdcff..e117f6ab335dd6 100644
--- a/paddle/phi/kernels/reduce_kernel_impl.h
+++ b/paddle/phi/kernels/reduce_kernel_impl.h
@@ -21,4 +21,6 @@ bool ReduceCheckIfOneDNNSupport(const KernelContext* ctx);
 
 bool ReduceGradCheckIfOneDNNSupport(const KernelContext* ctx);
 
+bool ReduceMeanCheckIfOneDNNSupport(const KernelContext* ctx);  // for mean
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/reduce_mean_kernel.cc b/paddle/phi/kernels/reduce_mean_kernel.cc
index 16b3abf0e29319..a657e7ba8c01d0 100644
--- a/paddle/phi/kernels/reduce_mean_kernel.cc
+++ b/paddle/phi/kernels/reduce_mean_kernel.cc
@@ -67,7 +67,7 @@ PD_REGISTER_KERNEL(mean, KPS, ALL_LAYOUT, phi::MeanKernel, float) {}
 #if defined(PADDLE_WITH_DNNL)
 PD_REGISTER_KERNEL(
     mean, OneDNN, ONEDNN, phi::MeanKernel, float, phi::dtype::bfloat16) {
-  kernel->check_if_onednn_kernel_support_ = phi::ReduceCheckIfOneDNNSupport;
+  kernel->check_if_onednn_kernel_support_ = phi::ReduceMeanCheckIfOneDNNSupport;
 }
 #endif
 
diff --git a/test/ir/pir/fused_pass/onednn/test_placement_pass_mean_op.py b/test/ir/pir/fused_pass/onednn/test_placement_pass_mean_op.py
new file mode 100644
index 00000000000000..6443a60c331f97
--- /dev/null
+++ b/test/ir/pir/fused_pass/onednn/test_placement_pass_mean_op.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import numpy as np
+from pass_test import PassTest
+
+import paddle
+
+paddle.enable_static()  # the test below builds paddle.static programs
+
+
+class TestMeanPlacementPass(PassTest):  # onednn_placement_pass on paddle.mean
+    def is_program_valid(self, program=None):
+        return True  # always True; `program` is not inspected
+
+    def build_ir_program(self):  # returns [main_prog, start_prog]
+        with paddle.pir_utils.IrGuard():
+            main_prog = paddle.static.Program()
+            start_prog = paddle.static.Program()
+            with paddle.pir.core.program_guard(main_prog, start_prog):
+                x = paddle.static.data(
+                    name='x', shape=[5, 2, 5, 5], dtype='float32'
+                )
+                mean = paddle.mean(x)  # called without an axis argument
+                out = paddle.assign(mean)
+                self.pass_attr_list = [{'onednn_placement_pass': {}}]
+
+                self.feeds = {  # random float32 data with x's declared shape
+                    "x": np.random.random((5, 2, 5, 5)).astype("float32"),
+                }
+                self.fetch_list = [out]
+                self.valid_op_map = {
+                    "onednn_op.mean": 1,  # presumably the expected op count
+                }
+                return [main_prog, start_prog]
+
+    def sample_program(self):
+        yield self.build_ir_program(), False  # TODO confirm meaning of False
+
+    def setUp(self):
+        self.places.append(paddle.CPUPlace())  # the only place added here
+
+    def test_check_output(self):
+        self.check_pass_correct()  # not defined in this class
+
+
+if __name__ == "__main__":
+    unittest.main()