1 change: 1 addition & 0 deletions paddle/fluid/framework/new_executor/feed_fetch_utils.cc
@@ -115,6 +115,7 @@ void FetchTensors(const std::vector<std::string>& job_fetch_names,
&(PADDLE_GET(phi::DenseTensor, fetch_list->at(micro_batch_id)[col]));
if (src.IsInitialized()) {
TensorCopy(src, platform::CPUPlace(), dst);
dst->set_lod(src.lod());
} else {
VLOG(6) << "Found " << var_name
<< " is not initialized and skip TensorCopy.";
4 changes: 0 additions & 4 deletions paddle/fluid/operators/CMakeLists.txt
@@ -106,10 +106,6 @@ if (WITH_GPU OR WITH_ROCM)
register_cu_kernel(class_center_sample_op SRCS class_center_sample_op.cu DEPS ${OP_HEADER_DEPS})
endif()

if (WITH_MKLDNN)
register_mkldnn_kernel(layer_norm_op SRCS layer_norm_mkldnn_op.cc DEPS ${OP_HEADER_DEPS})
endif()

if (WITH_GPU OR WITH_ROCM)
op_library(activation_op SRCS activation_op.cc activation_op.kps soft_relu_op.cu DEPS ${OP_HEADER_DEPS})
elseif (WITH_XPU_KP)
150 changes: 0 additions & 150 deletions paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc

This file was deleted.

2 changes: 0 additions & 2 deletions paddle/fluid/operators/unity_build_rule.cmake
@@ -131,8 +131,6 @@ register_unity_group(
l1_norm_op.cc
label_smooth_op.cc
generated_op
mkldnn/layer_norm_mkldnn_op.cc
mkldnn/layer_norm_mkldnn_op.cc
linspace_op.cc
load_combine_op.cc
load_op.cc)
18 changes: 10 additions & 8 deletions paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
@@ -114,7 +114,8 @@

# - op : fused_transpose

# - op : fusion_gru
- op : fusion_gru
extra_args : str mkldnn_data_type="float32", float scale_data=1.0, float shift_data=0.0, float[] scale_weights={1.0f}

# - op : fusion_lstm

@@ -130,7 +131,8 @@

- op : hardswish_grad

# - op : layer_norm
- op : layer_norm
extra_args : str mkldnn_data_type="float32", bool is_test=false

- op : leaky_relu

@@ -146,13 +148,13 @@
extra_args : bool is_test=false
data_format_tensors : x, out, mid_out, out_grad

# - op : matmul
# extra_args : str mkldnn_data_type="float32"
# layout_transform :
# arg_name: cur_paddle_data_layout
# tensors: x, y
- op : matmul
extra_args : str mkldnn_data_type="float32"
data_format_tensors : x, y

# - op : matmul_grad
- op : matmul_grad
extra_args : str mkldnn_data_type="float32"
data_format_tensors : x, y, out_grad

# - op : matmul_with_flatten

2 changes: 2 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -1433,6 +1433,8 @@
batched_input : BatchedInput
batched_out : BatchedOut
hidden : Hidden
attrs :
{scale_data : Scale_data, shift_data : Shift_data, scale_weights : Scale_weights}
extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", float Scale_data = 1.0f, float Shift_data = 0.0f, 'float[] Scale_weights = {1.0f}']

5 changes: 5 additions & 0 deletions paddle/phi/kernels/cpu/onednn_to_paddle_layout_kernel.cc
@@ -64,6 +64,11 @@ void OneDNN2PaddleLayout(const Context& dev_ctx,
VLOG(4) << "src_layout: " << src_layout << ", tmp_layout: " << tmp_layout;

if (src_layout != DataLayout::ONEDNN || !x.storage_properties_initialized()) {
if (!x.IsInitialized()) {
out->Resize(x.dims());
out->set_layout(tmp_layout);
return;
}
out->ShareDataWith(x);
out->ShareInplaceVersionCounterWith(x);
out->set_layout(static_cast<DataLayout>(tmp_layout));
147 changes: 147 additions & 0 deletions paddle/phi/kernels/onednn/layer_norm_kernel.cc
@@ -0,0 +1,147 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/layer_norm_kernel.h"

#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T>
class LayerNormOneDNNHandler
: public phi::funcs::
OneDNNHandlerNoCachingT<T, dnnl::layer_normalization_forward> {
public:
LayerNormOneDNNHandler(const std::vector<int64_t>& dims,
const float& epsilon,
const dnnl::normalization_flags& flags,
const bool& is_test,
const phi::DenseTensor* x,
const dnnl::engine engine,
Place cpu_place)
: phi::funcs::OneDNNHandlerNoCachingT<T,
dnnl::layer_normalization_forward>(
engine, cpu_place) {
const auto fwd_prop_kind = is_test ? dnnl::prop_kind::forward_inference
: dnnl::prop_kind::forward_training;

this->AcquireForwardPrimitiveDescriptor(
fwd_prop_kind, x->mem_desc(), x->mem_desc(), epsilon, flags);
}

std::tuple<std::shared_ptr<dnnl::memory>, std::shared_ptr<dnnl::memory>>
AcquireScaleShiftMemory(const phi::DenseTensor* scale,
const phi::DenseTensor* shift) {
auto scale_memory = this->AcquireMemoryFromPrimitive(
this->fwd_pd_->weights_desc(),
phi::funcs::to_void_cast<float>(scale->data<float>()));
auto shift_memory = this->AcquireMemoryFromPrimitive(
this->fwd_pd_->weights_desc(),
phi::funcs::to_void_cast<float>(shift->data<float>()));

return std::make_tuple(scale_memory, shift_memory);
}

std::shared_ptr<dnnl::memory> AcquireMeanMemory(phi::DenseTensor* mean) {
float* mean_data = mean->mutable_data<float>(
this->place_, this->fwd_pd_->mean_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
mean_data);
}

std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
phi::DenseTensor* variance) {
float* variance_data = variance->mutable_data<float>(
this->place_, this->fwd_pd_->variance_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
variance_data);
}
};

template <typename T, typename Context>
void LayerNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& scale_opt,
const paddle::optional<DenseTensor>& bias_opt,
float epsilon,
int begin_norm_axis,
DenseTensor* y,
DenseTensor* mean,
DenseTensor* var) {
bool is_test = dev_ctx.HasDnnAttr("is_test")
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("is_test"))
: false;

const auto& onednn_engine = dev_ctx.GetEngine();

auto src_tz = common::vectorize(x.dims());
PADDLE_ENFORCE_EQ(begin_norm_axis,
(src_tz.size() - 1),
phi::errors::InvalidArgument(
"MKL-DNN Layer Norm supports only last logical "
"axis:%d as begin_norm_axis.",
(src_tz.size() - 1)));

const bool with_scaleshift = (scale_opt && bias_opt);
dnnl::normalization_flags flags{};

if (with_scaleshift) {
flags |= dnnl::normalization_flags::use_scale |
dnnl::normalization_flags::use_shift;
}

LayerNormOneDNNHandler<T> handler(
src_tz, epsilon, flags, is_test, &x, onednn_engine, dev_ctx.GetPlace());

auto src_memory = handler.AcquireSrcMemory(&x);
auto dst_memory = handler.AcquireDstMemory(y);

auto layer_norm_p = handler.AcquireForwardPrimitive();

auto& astream = phi::OneDNNContext::tls().get_stream();
std::unordered_map<int, dnnl::memory> args = {{DNNL_ARG_SRC, *src_memory},
{DNNL_ARG_DST, *dst_memory}};

if (!is_test) {
auto mean_memory = handler.AcquireMeanMemory(mean);
auto variance_memory = handler.AcquireVarianceMemory(var);

args.insert({DNNL_ARG_MEAN, *mean_memory});
args.insert({DNNL_ARG_VARIANCE, *variance_memory});
}

if (with_scaleshift) {
auto scaleshift_mems = handler.AcquireScaleShiftMemory(scale_opt.get_ptr(),
bias_opt.get_ptr());
args.insert({DNNL_ARG_SCALE, *(std::get<0>(scaleshift_mems))});
args.insert({DNNL_ARG_SHIFT, *(std::get<1>(scaleshift_mems))});
}

layer_norm_p->execute(astream, args);
astream.wait();

y->set_mem_desc(dst_memory->get_desc());
}
} // namespace phi

PD_REGISTER_KERNEL(layer_norm,
OneDNN,
ONEDNN,
phi::LayerNormKernel,
float,
phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
}
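
A minimal usage sketch (not part of this diff): with an oneDNN-enabled CPU build, a plain layer_norm call normalizing over the last axis is the path this new PHI kernel serves. Whether dispatch actually reaches the oneDNN kernel depends on build options and runtime flags; the FLAGS_use_mkldnn setting below is an assumption about that routing, not something shown in this PR.

import paddle
import paddle.nn.functional as F

paddle.set_device("cpu")
# Assumption: this flag routes eligible CPU kernels to their oneDNN variants.
paddle.set_flags({"FLAGS_use_mkldnn": True})

x = paddle.rand([2, 3, 8], dtype="float32")
weight = paddle.ones([8], dtype="float32")
bias = paddle.zeros([8], dtype="float32")

# Normalizing over only the last axis matches the kernel's begin_norm_axis
# check above (the oneDNN implementation supports only the last logical axis).
y = F.layer_norm(x, normalized_shape=[8], weight=weight, bias=bias, epsilon=1e-5)
print(y.shape)  # [2, 3, 8]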
4 changes: 3 additions & 1 deletion test/legacy_test/test_fusion_gru_op.py
@@ -114,7 +114,9 @@ def setUp(self):
def test_check_output(self):
for use_seq in {True, False}:
self.attrs['use_seq'] = use_seq
self.check_output(check_dygraph=False)
self.check_output(
check_dygraph=False, check_pir_onednn=self.check_pir_onednn
)


class TestFusionGRUOpNoInitial(TestFusionGRUOp):
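
A hypothetical sketch of how a test subclass could opt into the new PIR oneDNN check; the attribute name mirrors the flag forwarded to check_output() above, but defining it this way in setUp() is an assumption, not taken from this diff.

class TestFusionGRUOpPirOneDNN(TestFusionGRUOp):
    def setUp(self):
        super().setUp()
        # Assumption: the OpTest base reads this attribute when check_output()
        # is invoked with check_pir_onednn=self.check_pir_onednn, as in the
        # change above.
        self.check_pir_onednn = True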