Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,8 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
} else {
// Broadcasting
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
ctx.GetPlace(), dout, dy,
ctx.InputName(framework::GradVarName("Out")),
CalculateBroadcastedDims(dout, dy));
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
reduction_p->execute(astream, {{DNNL_ARG_SRC, *reorder_src_memory_p},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
// Reduction is needed for broadcasting scenario
if (dout->dims() != dy->dims()) {
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, mkldnn_engine,
ctx.GetPlace(), dout, dy,
ctx.InputName(framework::GradVarName("Out")),
CalculateBroadcastedDims(dout, dy));
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, mkldnn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
// As source we use mem object with results from binary operation
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/operators/mkldnn/expand_v2_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ class ExpandMKLDNNKernel : public paddle::framework::OpKernel<T> {
out->set_format(x_format_tag);

paddle::platform::BroadcastDataMKLDNNHandler<T> handler(
dnnl::algorithm::binary_add, dev_ctx, onednn_engine, ctx.GetPlace(),
out, x, 0.0f, 1.0f, ctx.InputName("X"), x_vec_dims);
dnnl::algorithm::binary_add, onednn_engine, ctx.GetPlace(), out, x,
0.0f, 1.0f, x_vec_dims);

auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(out);
Expand Down Expand Up @@ -131,8 +131,8 @@ class ExpandGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
paddle::platform::GetMKLDNNFormat(reorder_dst_memory_p->get_desc()));
} else {
paddle::platform::ReductionMKLDNNHandler<T> handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
ctx.GetPlace(), dout, dx, ctx.InputName("X"), dx_vec_dims);
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dout, dx, dx_vec_dims);

auto src_memory_p = handler.AcquireSrcMemory(dout);
auto dst_memory_p = handler.AcquireDstMemory(dx);
Expand Down
106 changes: 49 additions & 57 deletions paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,58 +83,52 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext& dev_ctx,

template <typename T>
class MatMulMKLDNNHandler
: public paddle::platform::MKLDNNHandlerT<T, dnnl::matmul> {
: public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
public:
MatMulMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine,
MatMulMKLDNNHandler(const mkldnn::engine engine,
paddle::platform::Place cpu_place, Tensor* x,
bool trans_x, Tensor* y, bool trans_y, Tensor* out,
float scale, const std::string& uniq_name)
: paddle::platform::MKLDNNHandlerT<T, dnnl::matmul>(
dev_ctx, engine, cpu_place,
paddle::platform::CreateKey(dev_ctx, vectorize(x->dims()),
uniq_name)) {
if (!this->isCached()) {
auto mat_dim_x = paddle::operators::math::CreateMatrixDescriptor(
x->dims(), 0, trans_x);
auto mat_dim_y = paddle::operators::math::CreateMatrixDescriptor(
y->dims(), 0, trans_y);

memory::dim x_bs = mat_dim_x.batch_size_;
memory::dim y_bs = mat_dim_y.batch_size_;

memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
const memory::dim M = mat_dim_x.height_;
const memory::dim N = mat_dim_y.width_;
const memory::dim K = mat_dim_x.width_;

memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K};
memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N};
memory::dims out_dims = {out_bs, M, N};

memory::dims x_strides =
!trans_x ? memory::dims{M * K, K, 1} : memory::dims{M * K, 1, M};

memory::dims y_strides =
!trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K};
memory::dims out_strides = memory::dims{M * N, N, 1};

auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md = memory::desc(out_dims, MKLDNNGetDataType<T>(), out_strides);

dnnl::primitive_attr attrs;
if (scale != 1.0f) attrs.set_output_scales(0, {scale});

this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
}
float scale)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
cpu_place) {
auto mat_dim_x =
paddle::operators::math::CreateMatrixDescriptor(x->dims(), 0, trans_x);
auto mat_dim_y =
paddle::operators::math::CreateMatrixDescriptor(y->dims(), 0, trans_y);

memory::dim x_bs = mat_dim_x.batch_size_;
memory::dim y_bs = mat_dim_y.batch_size_;

memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
const memory::dim M = mat_dim_x.height_;
const memory::dim N = mat_dim_y.width_;
const memory::dim K = mat_dim_x.width_;

memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K};
memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N};
memory::dims out_dims = {out_bs, M, N};

memory::dims x_strides =
!trans_x ? memory::dims{M * K, K, 1} : memory::dims{M * K, 1, M};

memory::dims y_strides =
!trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K};
memory::dims out_strides = memory::dims{M * N, N, 1};

auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md = memory::desc(out_dims, MKLDNNGetDataType<T>(), out_strides);

dnnl::primitive_attr attrs;
if (scale != 1.0f) attrs.set_output_scales(0, {scale});

this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
}

std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<T>(input_data),
"@weights_mem_p");
to_void_cast<T>(input_data));
}
};

Expand Down Expand Up @@ -565,7 +559,7 @@ void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
const ExecutionContext& ctx, const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine& engine, Tensor* x, bool trans_x,
bool is_fold_init_dims_x, Tensor* y, bool trans_y, bool is_fold_init_dims_y,
Tensor* out, int execution_number) const {
Tensor* out) const {
// gradient is calculated in a different way when broadcasting is used
bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) &&
out->dims().size() == 2;
Expand All @@ -583,10 +577,8 @@ void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(

float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;

MatMulMKLDNNHandler<T> handler(dev_ctx, engine, ctx.GetPlace(), &x_combined,
trans_x, &y_combined, trans_y, out, alpha,
ctx.InputName(framework::GradVarName("Out")) +
std::to_string(execution_number));
MatMulMKLDNNHandler<T> handler(engine, ctx.GetPlace(), &x_combined, trans_x,
&y_combined, trans_y, out, alpha);

const auto src_memory_p = handler.AcquireSrcMemory(&x_combined);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined);
Expand Down Expand Up @@ -645,24 +637,24 @@ void MatMulGradMKLDNNKernel<T>::RunKernel(const ExecutionContext& ctx) const {

if (transpose_x && transpose_y) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, true, true, &dout,
true, false, dx, 0);
true, false, dx);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x,
true, false, dy, 1);
true, false, dy);
} else if (transpose_x) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, false, false,
&dout, true, false, dx, 0);
&dout, true, false, dx);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, false, false,
&dout, false, true, dy, 1);
&dout, false, true, dy);
} else if (transpose_y) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
&y, false, true, dx, 0);
&y, false, true, dx);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x,
false, true, dy, 1);
false, true, dy);
} else {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
&y, true, false, dx, 0);
&y, true, false, dx);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, true, true, &dout,
false, true, dy, 1);
false, true, dy);
}

if (dx) {
Expand Down
3 changes: 1 addition & 2 deletions paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine& engine, Tensor* x, bool trans_x,
bool is_fold_init_dims_x, Tensor* y, bool trans_y,
bool is_fold_init_dims_y, Tensor* out,
int execution_number) const;
bool is_fold_init_dims_y, Tensor* out) const;
void RunKernel(const ExecutionContext& ctx) const;
};
} // namespace operators
Expand Down
122 changes: 57 additions & 65 deletions paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,79 +31,72 @@ using paddle::framework::GradVarName;

template <typename T>
class MatMulV2MKLDNNHandler
: public paddle::platform::MKLDNNHandlerT<T, dnnl::matmul> {
: public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
public:
MatMulV2MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine,
MatMulV2MKLDNNHandler(const mkldnn::engine engine,
paddle::platform::Place cpu_place,
const std::vector<int64_t>& x_org_dims, bool trans_x,
const std::vector<int64_t>& y_org_dims, bool trans_y,
const std::string& uniq_name)
: paddle::platform::MKLDNNHandlerT<T, dnnl::matmul>(
dev_ctx, engine, cpu_place,
paddle::platform::CreateKey(dev_ctx, x_org_dims, uniq_name)) {
if (!this->isCached()) {
// M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);

const int MB_idx = x_dims.size() - 3;
const int H_idx = x_dims.size() - 2;
const int W_idx = x_dims.size() - 1;

if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]);
if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]);

const memory::dim M = x_dims[H_idx];
const memory::dim K = x_dims[W_idx];
const memory::dim N = y_dims[W_idx];

std::vector<int64_t> x_strides(x_dims.size() - 3, 1);
std::vector<int64_t> y_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_ddims(x_dims.size() - 3, 1);

x_strides.reserve(x_dims.size());
y_strides.reserve(x_dims.size());
out_strides.reserve(x_dims.size());

if (!trans_x) {
x_strides.insert(x_strides.end(), {M * K, K, 1});
} else {
x_strides.insert(x_strides.end(), {M * K, 1, M});
}
const std::vector<int64_t>& y_org_dims, bool trans_y)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
cpu_place) {
// M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);

const int MB_idx = x_dims.size() - 3;
const int H_idx = x_dims.size() - 2;
const int W_idx = x_dims.size() - 1;

if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]);
if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]);

const memory::dim M = x_dims[H_idx];
const memory::dim K = x_dims[W_idx];
const memory::dim N = y_dims[W_idx];

std::vector<int64_t> x_strides(x_dims.size() - 3, 1);
std::vector<int64_t> y_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_ddims(x_dims.size() - 3, 1);

x_strides.reserve(x_dims.size());
y_strides.reserve(x_dims.size());
out_strides.reserve(x_dims.size());

if (!trans_x) {
x_strides.insert(x_strides.end(), {M * K, K, 1});
} else {
x_strides.insert(x_strides.end(), {M * K, 1, M});
}

if (!trans_y) {
y_strides.insert(y_strides.end(), {N * K, N, 1});
} else {
y_strides.insert(y_strides.end(), {N * K, 1, K});
}
if (!trans_y) {
y_strides.insert(y_strides.end(), {N * K, N, 1});
} else {
y_strides.insert(y_strides.end(), {N * K, 1, K});
}

out_strides.insert(out_strides.end(), {M * N, N, 1});
out_ddims.insert(out_ddims.end(),
{std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N});
out_strides.insert(out_strides.end(), {M * N, N, 1});
out_ddims.insert(out_ddims.end(),
{std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N});

for (int i = x_dims.size() - 4; i >= 0; --i) {
out_ddims[i] = std::max(x_dims[i], y_dims[i]);
x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
}
for (int i = x_dims.size() - 4; i >= 0; --i) {
out_ddims[i] = std::max(x_dims[i], y_dims[i]);
x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
}

auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md =
memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides);
auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides);

this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md);
}
this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md);
}

std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<T>(input_data),
"@weights_mem_p");
to_void_cast<T>(input_data));
}
};

Expand All @@ -122,9 +115,8 @@ class MatMulV2MKLDNNKernel
const Tensor* y, std::vector<int64_t>& y_dims,
bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
int execution_number = 0) const {
MatMulV2MKLDNNHandler<T> handler(
dev_ctx, onednn_engine, ctx.GetPlace(), x_dims, trans_x, y_dims,
trans_y, ctx.InputName("X") + std::to_string(execution_number));
MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims,
trans_x, y_dims, trans_y);

const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y);
Expand Down Expand Up @@ -251,8 +243,8 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
const Tensor* dx_tmp, Tensor* dx,
std::vector<int64_t> dx_dims) const {
paddle::platform::ReductionMKLDNNHandler<T> handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
ctx.GetPlace(), dx_tmp, dx, ctx.InputName("X"), dx_dims);
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dx_tmp, dx, dx_dims);

auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
auto dst_memory_p = handler.AcquireDstMemory(dx);
Expand Down
Loading