1 change: 1 addition & 0 deletions paddle/fluid/framework/new_executor/feed_fetch_utils.cc
@@ -115,6 +115,7 @@ void FetchTensors(const std::vector<std::string>& job_fetch_names,
&(PADDLE_GET(phi::DenseTensor, fetch_list->at(micro_batch_id)[col]));
if (src.IsInitialized()) {
TensorCopy(src, platform::CPUPlace(), dst);
dst->set_lod(src.lod());
} else {
VLOG(6) << "Found " << var_name
<< " is not initialized and skip TensorCopy.";
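Note on the one-line addition above: on this fetch path TensorCopy fills in the destination's dims, dtype, and data, but does not appear to carry the LoD across, so variable-length fetch results came back with an empty LoD before this change. A minimal sketch of the intended behaviour as a standalone helper (the helper name is hypothetical, not part of the PR):

#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/core/dense_tensor.h"

// Copy a DenseTensor to CPU and keep its LoD; on this path TensorCopy alone
// leaves the destination LoD empty, which is what the fetch fix addresses.
void CopyWithLoD(const phi::DenseTensor& src, phi::DenseTensor* dst) {
  paddle::framework::TensorCopy(src, paddle::platform::CPUPlace(), dst);
  dst->set_lod(src.lod());
}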
4 changes: 0 additions & 4 deletions paddle/fluid/operators/CMakeLists.txt
@@ -106,10 +106,6 @@ if (WITH_GPU OR WITH_ROCM)
register_cu_kernel(class_center_sample_op SRCS class_center_sample_op.cu DEPS ${OP_HEADER_DEPS})
endif()

if (WITH_MKLDNN)
register_mkldnn_kernel(layer_norm_op SRCS layer_norm_mkldnn_op.cc DEPS ${OP_HEADER_DEPS})
endif()

if (WITH_GPU OR WITH_ROCM)
op_library(activation_op SRCS activation_op.cc activation_op.kps soft_relu_op.cu DEPS ${OP_HEADER_DEPS})
elseif (WITH_XPU_KP)
150 changes: 0 additions & 150 deletions paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc

This file was deleted.

2 changes: 0 additions & 2 deletions paddle/fluid/operators/unity_build_rule.cmake
@@ -131,8 +131,6 @@ register_unity_group(
l1_norm_op.cc
label_smooth_op.cc
generated_op
mkldnn/layer_norm_mkldnn_op.cc
mkldnn/layer_norm_mkldnn_op.cc
linspace_op.cc
load_combine_op.cc
load_op.cc)
18 changes: 10 additions & 8 deletions paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
@@ -114,7 +114,8 @@

# - op : fused_transpose

# - op : fusion_gru
- op : fusion_gru
extra_args : str mkldnn_data_type="float32", float scale_data=1.0, float shift_data=0.0, float[] scale_weights={1.0f}

# - op : fusion_lstm

@@ -130,7 +131,8 @@

- op : hardswish_grad

# - op : layer_norm
- op : layer_norm
extra_args : str mkldnn_data_type="float32", bool is_test=false

- op : leaky_relu

@@ -146,13 +148,13 @@
extra_args : bool is_test=false
data_format_tensors : x, out, mid_out, out_grad

# - op : matmul
# extra_args : str mkldnn_data_type="float32"
# layout_transform :
# arg_name: cur_paddle_data_layout
# tensors: x, y
- op : matmul
extra_args : str mkldnn_data_type="float32"
data_format_tensors : x, y

# - op : matmul_grad
- op : matmul_grad
extra_args : str mkldnn_data_type="float32"
data_format_tensors : x, y, out_grad

# - op : matmul_with_flatten

2 changes: 1 addition & 1 deletion paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
@@ -50,7 +50,7 @@
#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"

#ifdef PADDLE_WITH_DNNL
#include "build/paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h"
#include "paddle/fluid/pir/dialect/operator/trait/onednn.h"
2 changes: 2 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -1433,6 +1433,8 @@
batched_input : BatchedInput
batched_out : BatchedOut
hidden : Hidden
attrs :
{scale_data : Scale_data, shift_data : Shift_data, scale_weights : Scale_weights}
extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", float Scale_data = 1.0f, float Shift_data = 0.0f, 'float[] Scale_weights = {1.0f}']

5 changes: 5 additions & 0 deletions paddle/phi/kernels/cpu/onednn_to_paddle_layout_kernel.cc
@@ -64,6 +64,11 @@ void OneDNN2PaddleLayout(const Context& dev_ctx,
VLOG(4) << "src_layout: " << src_layout << ", tmp_layout: " << tmp_layout;

if (src_layout != DataLayout::ONEDNN || !x.storage_properties_initialized()) {
if (!x.IsInitialized()) {
out->Resize(x.dims());
out->set_layout(tmp_layout);
return;
}
out->ShareDataWith(x);
out->ShareInplaceVersionCounterWith(x);
out->set_layout(static_cast<DataLayout>(tmp_layout));
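Context for the new early-return branch above: falling through to ShareDataWith with an uninitialized input would trip the usual "holds no memory" enforce check on the source holder, so such an input (plausibly the mean/variance outputs that the new inference-mode layer_norm kernel later in this diff leaves unallocated) has to be handled first. The guard therefore forwards only the dims and the target layout and returns without touching any data.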
148 changes: 148 additions & 0 deletions paddle/phi/kernels/onednn/layer_norm_kernel.cc
@@ -0,0 +1,148 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/layer_norm_kernel.h"

#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T>
class LayerNormOneDNNHandler
: public phi::funcs::
OneDNNHandlerNoCachingT<T, dnnl::layer_normalization_forward> {
public:
LayerNormOneDNNHandler(const std::vector<int64_t>& dims,
const float& epsilon,
const dnnl::normalization_flags& flags,
const bool& is_test,
const phi::DenseTensor* x,
const dnnl::engine engine,
Place cpu_place)
: phi::funcs::OneDNNHandlerNoCachingT<T,
dnnl::layer_normalization_forward>(
engine, cpu_place) {
const auto fwd_prop_kind = is_test ? dnnl::prop_kind::forward_inference
: dnnl::prop_kind::forward_training;

this->AcquireForwardPrimitiveDescriptor(
fwd_prop_kind, x->mem_desc(), x->mem_desc(), epsilon, flags);
}

std::tuple<std::shared_ptr<dnnl::memory>, std::shared_ptr<dnnl::memory>>
AcquireScaleShiftMemory(const phi::DenseTensor* scale,
const phi::DenseTensor* shift) {
auto scale_memory = this->AcquireMemoryFromPrimitive(
this->fwd_pd_->weights_desc(),
phi::funcs::to_void_cast<float>(scale->data<float>()));
auto shift_memory = this->AcquireMemoryFromPrimitive(
this->fwd_pd_->weights_desc(),
phi::funcs::to_void_cast<float>(shift->data<float>()));

return std::make_tuple(scale_memory, shift_memory);
}

std::shared_ptr<dnnl::memory> AcquireMeanMemory(const OneDNNContext& dev_ctx,
phi::DenseTensor* mean) {
float* mean_data = dev_ctx.template Alloc<float>(
mean, this->fwd_pd_->mean_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
mean_data);
}

std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
const OneDNNContext& dev_ctx, phi::DenseTensor* variance) {
float* variance_data = dev_ctx.template Alloc<float>(
variance, this->fwd_pd_->variance_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
variance_data);
}
};

template <typename T, typename Context>
void LayerNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& scale_opt,
const paddle::optional<DenseTensor>& bias_opt,
float epsilon,
int begin_norm_axis,
DenseTensor* y,
DenseTensor* mean,
DenseTensor* var) {
bool is_test = dev_ctx.HasDnnAttr("is_test")
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("is_test"))
: false;

const auto& onednn_engine = dev_ctx.GetEngine();

auto src_tz = common::vectorize(x.dims());
PADDLE_ENFORCE_EQ(begin_norm_axis,
(src_tz.size() - 1),
phi::errors::InvalidArgument(
"MKL-DNN Layer Norm supports only last logical "
"axis:%d as begin_norm_axis.",
(src_tz.size() - 1)));

const bool with_scaleshift = (scale_opt && bias_opt);
dnnl::normalization_flags flags{};

if (with_scaleshift) {
flags |= dnnl::normalization_flags::use_scale |
dnnl::normalization_flags::use_shift;
}

LayerNormOneDNNHandler<T> handler(
src_tz, epsilon, flags, is_test, &x, onednn_engine, dev_ctx.GetPlace());

auto src_memory = handler.AcquireSrcMemory(&x);
auto dst_memory = handler.AcquireDstMemory(y);

auto layer_norm_p = handler.AcquireForwardPrimitive();

auto& astream = phi::OneDNNContext::tls().get_stream();
std::unordered_map<int, dnnl::memory> args = {{DNNL_ARG_SRC, *src_memory},
{DNNL_ARG_DST, *dst_memory}};

if (!is_test) {
auto mean_memory = handler.AcquireMeanMemory(dev_ctx, mean);
auto variance_memory = handler.AcquireVarianceMemory(dev_ctx, var);

args.insert({DNNL_ARG_MEAN, *mean_memory});
args.insert({DNNL_ARG_VARIANCE, *variance_memory});
}

if (with_scaleshift) {
auto scaleshift_mems = handler.AcquireScaleShiftMemory(scale_opt.get_ptr(),
bias_opt.get_ptr());
args.insert({DNNL_ARG_SCALE, *(std::get<0>(scaleshift_mems))});
args.insert({DNNL_ARG_SHIFT, *(std::get<1>(scaleshift_mems))});
}

layer_norm_p->execute(astream, args);
astream.wait();

y->set_mem_desc(dst_memory->get_desc());
}
} // namespace phi

PD_REGISTER_KERNEL(layer_norm,
OneDNN,
ONEDNN,
phi::LayerNormKernel,
float,
phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
}
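For readers less familiar with the oneDNN primitives used above, here is a hedged, self-contained sketch of the same layer-normalization forward call using the raw oneDNN 3.x C++ API (arbitrary shapes and epsilon, no Paddle types involved). It mirrors what LayerNormOneDNNHandler acquires: a primitive descriptor built from the prop_kind, src/dst memory descriptors, epsilon, and the scale/shift flags, plus the DNNL_ARG_* argument map passed to execute().

#include <dnnl.hpp>
#include <vector>

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  const memory::dims dims = {2, 8};  // batch x hidden; normalization runs over the last axis
  const float epsilon = 1e-5f;

  auto src_md = memory::desc(dims, memory::data_type::f32, memory::format_tag::nc);
  auto flags = normalization_flags::use_scale | normalization_flags::use_shift;

  // forward_inference corresponds to is_test == true in the kernel above;
  // forward_training would additionally produce mean/variance outputs.
  auto pd = layer_normalization_forward::primitive_desc(
      eng, prop_kind::forward_inference, src_md, src_md, epsilon, flags);

  std::vector<float> x(16, 1.0f), y(16, 0.0f), scale(8, 1.0f), shift(8, 0.0f);
  memory src_m(src_md, eng, x.data());
  memory dst_m(src_md, eng, y.data());
  memory scale_m(pd.weights_desc(), eng, scale.data());
  memory shift_m(pd.weights_desc(), eng, shift.data());

  layer_normalization_forward(pd).execute(strm,
                                          {{DNNL_ARG_SRC, src_m},
                                           {DNNL_ARG_DST, dst_m},
                                           {DNNL_ARG_SCALE, scale_m},
                                           {DNNL_ARG_SHIFT, shift_m}});
  strm.wait();
  return 0;
}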