Skip to content

Commit ecddcf0

Browse files
committed
remove mkldnn_helper code
1 parent a4ffbdd commit ecddcf0

3 files changed

Lines changed: 354 additions & 303 deletions

File tree

paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc

Lines changed: 121 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ class FCMKLDNNHandler
122122
post_operations.append_eltwise(
123123
activation_scale, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f);
124124
}
125-
platform::AppendActivation(ctx, post_operations, activation_scale);
125+
AppendActivation(ctx, post_operations, activation_scale);
126126

127127
if (ctx.HasAttr("fused_output_scale")) {
128128
float scale_alpha = ctx.Attr<float>("fused_output_scale");
@@ -154,6 +154,59 @@ class FCMKLDNNHandler
154154
}
155155
}
156156

157+
void AppendActivation(const ExecutionContext& ctx,
158+
dnnl::post_ops& post_ops, // NOLINT
159+
float activation_scale = 1.0f) {
160+
const auto invalid_attribute =
161+
ctx.HasAttr("fuse_activation")
162+
? ctx.Attr<std::string>("fuse_activation").empty()
163+
: true;
164+
if (invalid_attribute) return;
165+
166+
const auto fuse_activation = ctx.Attr<std::string>("fuse_activation");
167+
const auto fuse_alpha =
168+
ctx.HasAttr("fuse_alpha") ? ctx.Attr<float>("fuse_alpha") : 0.0f;
169+
const auto fuse_beta =
170+
ctx.HasAttr("fuse_beta") ? ctx.Attr<float>("fuse_beta") : 0.0f;
171+
172+
if (fuse_activation == "hard_sigmoid") {
173+
post_ops.append_eltwise(activation_scale,
174+
dnnl::algorithm::eltwise_linear,
175+
fuse_alpha,
176+
fuse_beta);
177+
post_ops.append_eltwise(
178+
activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
179+
} else {
180+
const std::unordered_map<std::string, dnnl::algorithm> activation_map = {
181+
{"abs", dnnl::algorithm::eltwise_abs},
182+
{"clip", dnnl::algorithm::eltwise_clip},
183+
{"gelu", dnnl::algorithm::eltwise_gelu_erf},
184+
{"gelu_erf", dnnl::algorithm::eltwise_gelu_erf},
185+
{"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
186+
{"hard_swish", dnnl::algorithm::eltwise_hardswish},
187+
{"leaky_relu", dnnl::algorithm::eltwise_relu},
188+
{"mish", dnnl::algorithm::eltwise_mish},
189+
{"relu", dnnl::algorithm::eltwise_relu},
190+
{"relu6", dnnl::algorithm::eltwise_bounded_relu},
191+
{"sigmoid", dnnl::algorithm::eltwise_logistic},
192+
{"sqrt", dnnl::algorithm::eltwise_sqrt},
193+
{"swish", dnnl::algorithm::eltwise_swish},
194+
{"tanh", dnnl::algorithm::eltwise_tanh}};
195+
196+
const auto& activation_type = activation_map.find(fuse_activation);
197+
198+
PADDLE_ENFORCE_NE(
199+
activation_type,
200+
activation_map.end(),
201+
platform::errors::InvalidArgument(
202+
"Activation '%s' not found in oneDNN algorithms mapper",
203+
fuse_activation));
204+
205+
post_ops.append_eltwise(
206+
activation_scale, activation_type->second, fuse_alpha, fuse_beta);
207+
}
208+
}
209+
157210
// Correct output scale, to take into account scaling of input and weights
158211
// Since the data that comes out of input and weight multiplication is
159212
// scaled with its own scales, this data needs to be divided by
@@ -396,6 +449,72 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
396449
}
397450
}
398451

452+
void SetOutMemDescWithUnsqueeze2FuseSupport(
453+
const framework::ExecutionContext& ctx,
454+
phi::DenseTensor* out,
455+
const dnnl::memory::desc& out_md) const {
456+
const std::vector<int>& fused_unsqueeze2_axes =
457+
ctx.Attr<std::vector<int>>("fused_unsqueeze2_axes");
458+
const std::vector<int64_t>& op_tz = out_md.dims();
459+
std::vector<int64_t> unsqueezed_op_tz(
460+
op_tz.size() + fused_unsqueeze2_axes.size(), 0);
461+
462+
for (const auto& axis : fused_unsqueeze2_axes) {
463+
int positive_axis = axis < 0 ? unsqueezed_op_tz.size() + axis : axis;
464+
unsqueezed_op_tz[positive_axis] = 1;
465+
}
466+
467+
int j = 0;
468+
for (size_t i = 0; i < unsqueezed_op_tz.size(); ++i) {
469+
if (unsqueezed_op_tz[i] == 0) {
470+
unsqueezed_op_tz[i] = op_tz[j++];
471+
}
472+
}
473+
out->set_mem_desc(out_md.reshape(unsqueezed_op_tz));
474+
out->Resize(phi::make_ddim(unsqueezed_op_tz));
475+
}
476+
477+
void SetOutMemDescWithReshape2FuseSupport(
478+
const framework::ExecutionContext& ctx,
479+
phi::DenseTensor* out,
480+
const dnnl::memory::desc& out_md) const {
481+
std::vector<int64_t> fused_reshape2_shape(
482+
ctx.Attr<std::vector<int>>("fused_reshape2_shape").begin(),
483+
ctx.Attr<std::vector<int>>("fused_reshape2_shape").end());
484+
485+
const int out_shape_numel = out->numel();
486+
const int new_shape_numel = std::accumulate(fused_reshape2_shape.begin(),
487+
fused_reshape2_shape.end(),
488+
1,
489+
std::multiplies<int64_t>());
490+
491+
for (size_t i = 0; i < fused_reshape2_shape.size(); ++i) {
492+
if (fused_reshape2_shape[i] == -1) {
493+
fused_reshape2_shape[i] = -out_shape_numel / new_shape_numel;
494+
break;
495+
}
496+
}
497+
498+
out->set_mem_desc(out_md.reshape(fused_reshape2_shape));
499+
out->Resize(phi::make_ddim(fused_reshape2_shape));
500+
}
501+
502+
void SetOutMemDescWithLogicalLayoutFusesSupport(
503+
const framework::ExecutionContext& ctx,
504+
phi::DenseTensor* out,
505+
const dnnl::memory::desc& out_md) const {
506+
if (ctx.HasAttr("fused_unsqueeze2_axes")) {
507+
SetOutMemDescWithUnsqueeze2FuseSupport(ctx, out, out_md);
508+
} else if (ctx.HasAttr("fused_reshape2_shape")) {
509+
SetOutMemDescWithReshape2FuseSupport(ctx, out, out_md);
510+
} else if (ctx.HasAttr("fused_squeeze2_axes")) {
511+
out->set_mem_desc(out_md);
512+
out->Resize(phi::make_ddim(out_md.dims()));
513+
} else {
514+
out->set_mem_desc(out_md);
515+
}
516+
}
517+
399518
template <typename T_out, typename T_w>
400519
void RunKernel(const framework::ExecutionContext& ctx) const {
401520
const auto& dev_ctx = ctx.template device_context<OneDNNContext>();
@@ -504,7 +623,7 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
504623
dev_ctx.SetBlob(cache_key, ip_cache);
505624
}
506625

507-
platform::SetOutMemDescWithLogicalLayoutFusesSupport(
626+
SetOutMemDescWithLogicalLayoutFusesSupport(
508627
ctx,
509628
out,
510629
dst_memory_p->get_desc().reshape(phi::vectorize(out->dims())));

paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc

Lines changed: 233 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ namespace {
2121
using dnnl::memory;
2222
using paddle::framework::ExecutionContext;
2323
using paddle::framework::GradVarName;
24-
using paddle::platform::MatMulV2MKLDNNHandler;
2524
using phi::OneDNNContext;
2625
using phi::vectorize;
2726
using phi::funcs::OneDNNGetDataType;
@@ -82,6 +81,239 @@ phi::DDim GetDimForInput(const ExecutionContext &ctx, std::string input_name) {
8281
return input_dims;
8382
}
8483

84+
template <typename XT, typename YT, typename OT>
85+
class MatMulV2MKLDNNHandler
86+
: public phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
87+
public:
88+
MatMulV2MKLDNNHandler(const ExecutionContext &ctx,
89+
const dnnl::engine engine,
90+
paddle::platform::Place cpu_place,
91+
const std::vector<int64_t> &x_org_dims,
92+
bool trans_x,
93+
const std::vector<int64_t> &y_org_dims,
94+
bool trans_y,
95+
bool is_output_fused,
96+
const std::vector<int64_t> &x_strides_override,
97+
const std::vector<int64_t> &y_strides_override)
98+
: phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul>(engine,
99+
cpu_place) {
100+
// M X K * K X N
101+
std::vector<int64_t> x_dims(x_org_dims);
102+
std::vector<int64_t> y_dims(y_org_dims);
103+
104+
const int MB_idx = x_dims.size() - 3;
105+
const int H_idx = x_dims.size() - 2;
106+
const int W_idx = x_dims.size() - 1;
107+
108+
if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]);
109+
if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]);
110+
111+
const memory::dim M = x_dims[H_idx];
112+
const memory::dim K = x_dims[W_idx];
113+
const memory::dim N = y_dims[W_idx];
114+
115+
std::vector<int64_t> x_strides(x_dims.size() - 3, 1);
116+
std::vector<int64_t> y_strides(x_dims.size() - 3, 1);
117+
std::vector<int64_t> out_strides(x_dims.size() - 3, 1);
118+
std::vector<int64_t> out_ddims(x_dims.size() - 3, 1);
119+
120+
x_strides.reserve(x_dims.size());
121+
y_strides.reserve(x_dims.size());
122+
out_strides.reserve(x_dims.size());
123+
124+
if (!x_strides_override.empty()) {
125+
x_strides = x_strides_override;
126+
} else {
127+
if (!trans_x) {
128+
x_strides.insert(x_strides.end(), {M * K, K, 1});
129+
} else {
130+
x_strides.insert(x_strides.end(), {M * K, 1, M});
131+
}
132+
}
133+
134+
if (!y_strides_override.empty()) {
135+
y_strides = y_strides_override;
136+
} else {
137+
if (!trans_y) {
138+
y_strides.insert(y_strides.end(), {N * K, N, 1});
139+
} else {
140+
y_strides.insert(y_strides.end(), {N * K, 1, K});
141+
}
142+
}
143+
144+
out_strides.insert(out_strides.end(), {M * N, N, 1});
145+
out_ddims.insert(out_ddims.end(),
146+
{std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N});
147+
148+
for (int i = x_dims.size() - 4; i >= 0; --i) {
149+
out_ddims[i] = std::max(x_dims[i], y_dims[i]);
150+
if (x_strides_override.empty()) {
151+
x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
152+
}
153+
if (y_strides_override.empty()) {
154+
y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
155+
}
156+
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
157+
}
158+
159+
// TODO(jczaja): Why not for int8??
160+
if (!phi::funcs::is_int8<OT>() && is_output_fused) {
161+
out_strides = FakeTransposeStrides(out_ddims);
162+
}
163+
164+
auto x_md =
165+
memory::desc(x_dims, phi::funcs::OneDNNGetDataType<XT>(), x_strides);
166+
auto y_md =
167+
memory::desc(y_dims, phi::funcs::OneDNNGetDataType<YT>(), y_strides);
168+
auto out_md = memory::desc(
169+
out_ddims, phi::funcs::OneDNNGetDataType<OT>(), out_strides);
170+
171+
const dnnl::primitive_attr matmul_attrs = CreateMatmulAttrs(ctx);
172+
173+
this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md);
174+
}
175+
176+
void AppendActivation(const ExecutionContext &ctx,
177+
dnnl::post_ops &post_ops, // NOLINT
178+
float activation_scale = 1.0f) {
179+
const auto invalid_attribute =
180+
ctx.HasAttr("fuse_activation")
181+
? ctx.Attr<std::string>("fuse_activation").empty()
182+
: true;
183+
if (invalid_attribute) return;
184+
185+
const auto fuse_activation = ctx.Attr<std::string>("fuse_activation");
186+
const auto fuse_alpha =
187+
ctx.HasAttr("fuse_alpha") ? ctx.Attr<float>("fuse_alpha") : 0.0f;
188+
const auto fuse_beta =
189+
ctx.HasAttr("fuse_beta") ? ctx.Attr<float>("fuse_beta") : 0.0f;
190+
191+
if (fuse_activation == "hard_sigmoid") {
192+
post_ops.append_eltwise(activation_scale,
193+
dnnl::algorithm::eltwise_linear,
194+
fuse_alpha,
195+
fuse_beta);
196+
post_ops.append_eltwise(
197+
activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
198+
} else {
199+
const std::unordered_map<std::string, dnnl::algorithm> activation_map = {
200+
{"abs", dnnl::algorithm::eltwise_abs},
201+
{"clip", dnnl::algorithm::eltwise_clip},
202+
{"gelu", dnnl::algorithm::eltwise_gelu_erf},
203+
{"gelu_erf", dnnl::algorithm::eltwise_gelu_erf},
204+
{"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
205+
{"hard_swish", dnnl::algorithm::eltwise_hardswish},
206+
{"leaky_relu", dnnl::algorithm::eltwise_relu},
207+
{"mish", dnnl::algorithm::eltwise_mish},
208+
{"relu", dnnl::algorithm::eltwise_relu},
209+
{"relu6", dnnl::algorithm::eltwise_bounded_relu},
210+
{"sigmoid", dnnl::algorithm::eltwise_logistic},
211+
{"sqrt", dnnl::algorithm::eltwise_sqrt},
212+
{"swish", dnnl::algorithm::eltwise_swish},
213+
{"tanh", dnnl::algorithm::eltwise_tanh}};
214+
215+
const auto &activation_type = activation_map.find(fuse_activation);
216+
217+
PADDLE_ENFORCE_NE(
218+
activation_type,
219+
activation_map.end(),
220+
phi::errors::InvalidArgument(
221+
"Activation '%s' not found in oneDNN algorithms mapper",
222+
fuse_activation));
223+
224+
post_ops.append_eltwise(
225+
activation_scale, activation_type->second, fuse_alpha, fuse_beta);
226+
}
227+
}
228+
229+
float ComputeOutputScale(const ExecutionContext &ctx) {
230+
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
231+
if (ctx.HasAttr("Scale_x") && ctx.HasAttr("Scale_y") &&
232+
ctx.HasAttr("Scale_out")) {
233+
float scale_x = ctx.Attr<float>("Scale_x");
234+
float scale_y = ctx.Attr<float>("Scale_y");
235+
bool force_fp32_out = ctx.HasAttr("force_fp32_output")
236+
? ctx.Attr<bool>("force_fp32_output")
237+
: false;
238+
float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
239+
alpha *= scale_out / (scale_x * scale_y);
240+
}
241+
return alpha;
242+
}
243+
244+
dnnl::primitive_attr CreateMatmulAttrs(const ExecutionContext &ctx) {
245+
dnnl::primitive_attr matmul_attrs;
246+
dnnl::post_ops post_operations;
247+
248+
float scale_out = ComputeOutputScale(ctx);
249+
if (scale_out != 1.0f) {
250+
matmul_attrs.set_output_scales(0, {scale_out});
251+
}
252+
253+
if (ctx.HasInput("ResidualData")) {
254+
auto *residual_data = ctx.Input<phi::DenseTensor>("ResidualData");
255+
auto residual_data_tz = phi::vectorize(residual_data->dims());
256+
auto residual_data_md = memory::desc(residual_data_tz,
257+
phi::funcs::OneDNNGetDataType<OT>(),
258+
dnnl::memory::format_tag::any);
259+
post_operations.append_binary(dnnl::algorithm::binary_add,
260+
residual_data_md);
261+
if (ctx.HasAttr("Scale_in_eltwise")) {
262+
float sum_scale = scale_out / ctx.Attr<float>("Scale_in_eltwise");
263+
post_operations.append_sum(sum_scale);
264+
}
265+
}
266+
267+
AppendActivation(ctx, post_operations);
268+
269+
if (ctx.HasAttr("fused_output_scale")) {
270+
float scale_alpha = ctx.Attr<float>("fused_output_scale");
271+
post_operations.append_eltwise(
272+
1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
273+
}
274+
275+
matmul_attrs.set_post_ops(post_operations);
276+
return matmul_attrs;
277+
}
278+
279+
std::vector<int64_t> FakeTransposeStrides(
280+
const std::vector<int64_t> &matmul_out_dims) const {
281+
// fuse matmul_v2 + transpose + reshape guarantees that output is 4D and
282+
// transpose axis are: {0, 2, 1, 3}
283+
std::vector<int64_t> transpose_axis = {0, 2, 1, 3};
284+
std::vector<int64_t> fake_strides(transpose_axis.size());
285+
int ndims = static_cast<int>(transpose_axis.size());
286+
287+
int total_stride = 1;
288+
289+
for (int i = ndims - 1; i >= 0; --i) {
290+
fake_strides[transpose_axis[i]] = total_stride;
291+
total_stride *= matmul_out_dims[transpose_axis[i]];
292+
}
293+
294+
return fake_strides;
295+
}
296+
297+
std::shared_ptr<memory> AcquireWeightsMemory(const phi::DenseTensor *input) {
298+
const YT *input_data = input->data<YT>();
299+
return this->AcquireMemoryFromPrimitive(
300+
this->fwd_pd_->weights_desc(),
301+
phi::funcs::to_void_cast<YT>(input_data));
302+
}
303+
304+
std::shared_ptr<dnnl::memory> AcquireDstMemory(phi::DenseTensor *output) {
305+
// We cannot use base AcquireDstMemory as it makes an allocation request
306+
// base on DST memory primitive size. This is fine in general, but in MatMul
307+
// we have primitive that covers only one batch of Data and then shift
308+
// pointer for every new batch. Hence phi::DenseTensor size is bigger that
309+
// dst memory primitive size. So would we request less memory that is there
310+
// and it triggers an assertion. So as there is no 'any' format here we can
311+
// leave default size of phi::DenseTensor as computed in ComputeInferShape
312+
OT *ptr = output->mutable_data<OT>(this->place_);
313+
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
314+
}
315+
};
316+
85317
template <typename XT, typename YT, typename OT>
86318
class MatMulMKLDNNHandler
87319
: public phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul> {

0 commit comments

Comments
 (0)