Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions paddle/phi/kernels/reduce_kernel_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,16 @@ namespace phi {
// oneDNN's reduction kernel is optimized only for reducing throughout the
// most outer dims, so in case of another type of reduction, it would be
// better to fallback to native implementation
inline bool HasOptimizedOneDNNKernel(const KernelContext* ctx) {
inline bool HasOptimizedOneDNNKernel(const KernelContext* ctx,
const bool mean_op) {
const DenseTensor& x = ctx->InputAt<phi::DenseTensor>(0);
const TensorRef& dims_tmp = ctx->AttrAt<TensorRef>(0);
IntArray dims_array = IntArray(*dims_tmp.Get());
IntArray dims_array;
if (mean_op) {
dims_array = ctx->AttrAt<IntArray>(0);
} else {
const TensorRef& dims_tmp = ctx->AttrAt<TensorRef>(0);
dims_array = IntArray(*dims_tmp.Get());
}
int ndims = x.dims().size();
const bool reduce_all = recompute_reduce_all(x, dims_array);
auto dims = dims_array.GetData();
Expand Down Expand Up @@ -53,7 +59,15 @@ inline bool HasOptimizedOneDNNKernel(const KernelContext* ctx) {

bool ReduceCheckIfOneDNNSupport(const KernelContext* ctx) {
if (ctx->InputAt<phi::DenseTensor>(0).dims().size() > 5 ||
!HasOptimizedOneDNNKernel(ctx)) {
!HasOptimizedOneDNNKernel(ctx, false)) {
return false;
}
return true;
}

bool ReduceMeanCheckIfOneDNNSupport(const KernelContext* ctx) {
if (ctx->InputAt<phi::DenseTensor>(0).dims().size() > 5 ||
!HasOptimizedOneDNNKernel(ctx, true)) {
return false;
}
return true;
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/kernels/reduce_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@ bool ReduceCheckIfOneDNNSupport(const KernelContext* ctx);

bool ReduceGradCheckIfOneDNNSupport(const KernelContext* ctx);

bool ReduceMeanCheckIfOneDNNSupport(const KernelContext* ctx);

} // namespace phi
2 changes: 1 addition & 1 deletion paddle/phi/kernels/reduce_mean_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ PD_REGISTER_KERNEL(mean, KPS, ALL_LAYOUT, phi::MeanKernel, float) {}
#if defined(PADDLE_WITH_DNNL)
PD_REGISTER_KERNEL(
mean, OneDNN, ONEDNN, phi::MeanKernel, float, phi::dtype::bfloat16) {
kernel->check_if_onednn_kernel_support_ = phi::ReduceCheckIfOneDNNSupport;
kernel->check_if_onednn_kernel_support_ = phi::ReduceMeanCheckIfOneDNNSupport;
}
#endif

Expand Down
60 changes: 60 additions & 0 deletions test/ir/pir/fused_pass/onednn/test_placement_pass_mean_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np
from pass_test import PassTest

import paddle

paddle.enable_static()


class TestMeanPlacementPass(PassTest):
    """Verify that `onednn_placement_pass` rewrites a plain `mean` op into
    its oneDNN counterpart (`onednn_op.mean`) in a PIR program.
    """

    def is_program_valid(self, program=None):
        # Every program built by this test is acceptable as-is.
        return True

    def build_ir_program(self):
        # Build a minimal PIR program: data -> mean (full reduction) -> assign.
        with paddle.pir_utils.IrGuard():
            main_program = paddle.static.Program()
            startup_program = paddle.static.Program()
            with paddle.pir.core.program_guard(
                main_program, startup_program
            ):
                input_tensor = paddle.static.data(
                    name='x', shape=[5, 2, 5, 5], dtype='float32'
                )
                reduced = paddle.mean(input_tensor)
                out = paddle.assign(reduced)
                # Only the placement pass is exercised here.
                self.pass_attr_list = [{'onednn_placement_pass': {}}]

                self.feeds = {
                    "x": np.random.random((5, 2, 5, 5)).astype("float32"),
                }
                self.fetch_list = [out]
                # After the pass, exactly one oneDNN mean op must be present.
                self.valid_op_map = {
                    "onednn_op.mean": 1,
                }
                return [main_program, startup_program]

    def sample_program(self):
        # Single program variant; the False flag is consumed by the
        # PassTest harness (presumably "no PIR translation needed" —
        # TODO confirm against pass_test.PassTest).
        yield self.build_ir_program(), False

    def setUp(self):
        # `places` is provided by the PassTest base class; run on CPU only,
        # since oneDNN kernels are CPU kernels.
        self.places.append(paddle.CPUPlace())

    def test_check_output(self):
        # Runs the pass and compares outputs before/after the rewrite.
        self.check_pass_correct()


# Standard unittest entry point: discovers and runs the test cases above
# when this file is executed directly.
if __name__ == "__main__":
    unittest.main()