Skip to content

Commit e726960

Browse files
author
qipengh
authored
[MLU] add lookup_table_v2 and unstack op (PaddlePaddle#42847)
1 parent 313f5d0 commit e726960

File tree

6 files changed

+512
-11
lines changed

6 files changed

+512
-11
lines changed
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software
9+
distributed under the License is distributed on an "AS IS" BASIS,
10+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
See the License for the specific language governing permissions and
12+
limitations under the License. */
13+
14+
#include "paddle/fluid/framework/op_registry.h"
15+
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
using Tensor = framework::Tensor;
21+
constexpr int64_t kNoPadding = -1;
22+
23+
template <typename T>
24+
class LookupTableV2MLUKernel : public framework::OpKernel<T> {
25+
public:
26+
void Compute(const framework::ExecutionContext &ctx) const override {
27+
auto *ids_t = ctx.Input<framework::LoDTensor>("Ids"); // int tensor
28+
auto *output_t = ctx.Output<framework::LoDTensor>("Out"); // float tensor
29+
auto *table_t = ctx.Input<framework::LoDTensor>("W");
30+
31+
auto *table_var = ctx.InputVar("W");
32+
PADDLE_ENFORCE_EQ(
33+
table_var->IsType<framework::LoDTensor>(), true,
34+
platform::errors::InvalidArgument("mlu only accept LoDTensor"));
35+
output_t->mutable_data<T>(ctx.GetPlace());
36+
37+
MLUCnnlTensorDesc ids_desc(*ids_t);
38+
MLUCnnlTensorDesc table_desc(*table_t);
39+
MLUCnnlTensorDesc output_desc(*output_t);
40+
41+
int64_t padding_idx = ctx.Attr<int64_t>("padding_idx");
42+
if (padding_idx == kNoPadding) {
43+
MLUCnnl::GatherFunctor(ctx, /*axis=*/0, /*batch_dims=*/0,
44+
table_desc.get(), GetBasePtr(table_t),
45+
ids_desc.get(), GetBasePtr(ids_t),
46+
output_desc.get(), GetBasePtr(output_t));
47+
} else {
48+
Tensor tmp_table_t(table_t->type());
49+
tmp_table_t.mutable_data<T>(table_t->dims(), ctx.GetPlace());
50+
51+
Tensor index;
52+
index.mutable_data<int32_t>({1, 1}, ctx.GetPlace());
53+
auto idx_value = static_cast<int32_t>(padding_idx);
54+
MLUCnnlTensorDesc index_desc(index);
55+
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &idx_value, index_desc.get(),
56+
GetBasePtr(&index));
57+
58+
auto update_dim = phi::make_ddim({1, table_t->dims()[1]});
59+
Tensor update;
60+
update.mutable_data<T>(update_dim, ctx.GetPlace());
61+
62+
auto update_value = static_cast<T>(0);
63+
MLUCnnlTensorDesc update_desc(update);
64+
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &update_value,
65+
update_desc.get(), GetBasePtr(&update));
66+
67+
MLUCnnlTensorDesc tmp_table_desc(tmp_table_t);
68+
MLUCnnl::ScatterNd(
69+
ctx, CNNL_SCATTERND_UPDATE, index_desc.get(), GetBasePtr(&index),
70+
update_desc.get(), GetBasePtr(&update), table_desc.get(),
71+
GetBasePtr(table_t), tmp_table_desc.get(), GetBasePtr(&tmp_table_t));
72+
73+
MLUCnnl::GatherFunctor(ctx, /*axis=*/0, /*batch_dims=*/0,
74+
tmp_table_desc.get(), GetBasePtr(&tmp_table_t),
75+
ids_desc.get(), GetBasePtr(ids_t),
76+
output_desc.get(), GetBasePtr(output_t));
77+
}
78+
}
79+
};
80+
81+
template <typename T>
82+
class LookupTableV2GradMLUKernel : public framework::OpKernel<T> {
83+
public:
84+
void Compute(const framework::ExecutionContext &ctx) const override {
85+
auto *ids_t = ctx.Input<framework::LoDTensor>("Ids");
86+
auto *output_grad_t =
87+
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
88+
auto *table_grad_t =
89+
ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
90+
table_grad_t->mutable_data<T>(ctx.GetPlace());
91+
92+
int padding_idx = static_cast<int>(ctx.Attr<int64_t>("padding_idx"));
93+
94+
Tensor ids_int32(ids_t->dtype());
95+
if (ids_t->dtype() != DataType::INT32) {
96+
ids_int32.mutable_data<int>(ids_t->dims(), ctx.GetPlace());
97+
MLUCnnlTensorDesc ids_desc(*ids_t);
98+
MLUCnnlTensorDesc ids_int32_desc(ids_int32);
99+
auto cast_type = GetCastDataType(ids_t->dtype(), DataType::INT32);
100+
MLUCnnl::Cast(ctx, cast_type, ids_desc.get(), GetBasePtr(ids_t),
101+
ids_int32_desc.get(), GetBasePtr(&ids_int32));
102+
} else {
103+
ids_int32 = *ids_t;
104+
}
105+
106+
MLUCnnlTensorDesc ids_int32_desc(ids_int32);
107+
MLUCnnlTensorDesc output_grad_desc(*output_grad_t);
108+
MLUCnnlTensorDesc table_grad_desc(*table_grad_t);
109+
110+
MLUCnnl::EmbeddingBackward(ctx, padding_idx, false, ids_int32_desc.get(),
111+
GetBasePtr(&ids_int32), output_grad_desc.get(),
112+
GetBasePtr(output_grad_t), table_grad_desc.get(),
113+
GetBasePtr(table_grad_t));
114+
}
115+
};
116+
} // namespace operators
117+
} // namespace paddle
118+
119+
namespace ops = paddle::operators;
120+
namespace plat = paddle::platform;
121+
122+
REGISTER_OP_MLU_KERNEL(lookup_table_v2, ops::LookupTableV2MLUKernel<float>,
123+
ops::LookupTableV2MLUKernel<int>,
124+
ops::LookupTableV2MLUKernel<plat::float16>);
125+
126+
REGISTER_OP_MLU_KERNEL(lookup_table_v2_grad,
127+
ops::LookupTableV2GradMLUKernel<float>,
128+
ops::LookupTableV2GradMLUKernel<int>,
129+
ops::LookupTableV2GradMLUKernel<plat::float16>);

paddle/fluid/operators/mlu/mlu_baseop.cc

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ cnnlCastDataType_t GetCastDataType(const VT::Type& src_type,
3434
return cast_type;
3535
}
3636

37+
// Convenience overload: accepts the newer phi DataType enum and forwards to
// the proto-VarType based implementation above.
cnnlCastDataType_t GetCastDataType(const DataType& src_type,
                                   const DataType& dst_type) {
  const auto src_proto = framework::TransToProtoVarType(src_type);
  const auto dst_proto = framework::TransToProtoVarType(dst_type);
  return GetCastDataType(src_proto, dst_proto);
}
42+
3743
bool MLUSupportsCast(const VT::Type& src_type, const VT::Type& dst_type) {
3844
for (auto it = MLU_SUPPORTED_CAST_TYPE.begin();
3945
it != MLU_SUPPORTED_CAST_TYPE.end(); ++it) {
@@ -2713,17 +2719,16 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
27132719
output_desc, output));
27142720
}
27152721

2716-
/* static */ void MLUCnnl::ScatterNd(const ExecutionContext& ctx,
2717-
const cnnlTensorDescriptor_t indices_desc,
2718-
const void* indices,
2719-
const cnnlTensorDescriptor_t updates_desc,
2720-
const void* updates,
2721-
const cnnlTensorDescriptor_t output_desc,
2722-
void* output) {
2722+
/* static */ void MLUCnnl::ScatterNd(
2723+
const ExecutionContext& ctx, cnnlScatterNdMode_t mode,
2724+
const cnnlTensorDescriptor_t indices_desc, const void* indices,
2725+
const cnnlTensorDescriptor_t updates_desc, const void* updates,
2726+
const cnnlTensorDescriptor_t input_desc, const void* input,
2727+
const cnnlTensorDescriptor_t output_desc, void* output) {
27232728
cnnlHandle_t handle = GetHandleFromCTX(ctx);
2724-
PADDLE_ENFORCE_MLU_SUCCESS(cnnlScatterNd(handle, indices_desc, indices,
2725-
updates_desc, updates, output_desc,
2726-
output));
2729+
PADDLE_ENFORCE_MLU_SUCCESS(
2730+
cnnlScatterNd_v2(handle, mode, indices_desc, indices, updates_desc,
2731+
updates, input_desc, input, output_desc, output));
27272732
}
27282733

27292734
/* static */ void MLUCnnl::BitWise(
@@ -2777,5 +2782,26 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
27772782
cnnlReciprocal(handle, input_desc, input, output_desc, output));
27782783
}
27792784

2785+
/* static */ void MLUCnnl::EmbeddingBackward(
2786+
const ExecutionContext& ctx, int padding_idx, bool scale_grad_by_freq,
2787+
const cnnlTensorDescriptor_t indices_desc, const void* indices,
2788+
const cnnlTensorDescriptor_t diff_desc, const void* diff,
2789+
const cnnlTensorDescriptor_t output_desc, void* output) {
2790+
cnnlHandle_t handle = GetHandleFromCTX(ctx);
2791+
2792+
size_t workspace_size;
2793+
PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetEmbeddingBackwardWorkspaceSize(
2794+
handle, diff_desc, output_desc, scale_grad_by_freq, &workspace_size));
2795+
2796+
auto& dev_ctx = GetDevCtxFromCTX(ctx);
2797+
Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
2798+
{static_cast<int64_t>(workspace_size)}, dev_ctx);
2799+
void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
2800+
2801+
PADDLE_ENFORCE_MLU_SUCCESS(cnnlEmbeddingBackward(
2802+
handle, padding_idx, scale_grad_by_freq, indices_desc, indices, diff_desc,
2803+
diff, workspace_ptr, workspace_size, output_desc, output));
2804+
}
2805+
27802806
} // namespace operators
27812807
} // namespace paddle

paddle/fluid/operators/mlu/mlu_baseop.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,10 @@ const std::map<std::pair<VT::Type, VT::Type>, cnnlCastDataType_t>
175175

176176
cnnlCastDataType_t GetCastDataType(const VT::Type& src_type,
177177
const VT::Type& dst_type);
178+
179+
cnnlCastDataType_t GetCastDataType(const DataType& src_type,
180+
const DataType& dst_type);
181+
178182
bool MLUSupportsCast(const VT::Type& src_type, const VT::Type& dst_type);
179183

180184
cnnlDeviceType_t GetCnnlDev(int dev_ordinal);
@@ -1202,11 +1206,13 @@ class MLUCnnl {
12021206
const void* k, const int k_int,
12031207
const cnnlTensorDescriptor_t output_desc, void* output);
12041208

1205-
static void ScatterNd(const ExecutionContext& ctx,
1209+
static void ScatterNd(const ExecutionContext& ctx, cnnlScatterNdMode_t mode,
12061210
const cnnlTensorDescriptor_t indices_desc,
12071211
const void* indices,
12081212
const cnnlTensorDescriptor_t updates_desc,
12091213
const void* updates,
1214+
const cnnlTensorDescriptor_t input_desc,
1215+
const void* input,
12101216
const cnnlTensorDescriptor_t output_desc, void* output);
12111217

12121218
static void BitWise(const ExecutionContext& ctx,
@@ -1227,6 +1233,12 @@ class MLUCnnl {
12271233
const void* input,
12281234
const cnnlTensorDescriptor_t output_desc,
12291235
void* output);
1236+
1237+
static void EmbeddingBackward(
1238+
const ExecutionContext& ctx, int padding_idx, bool scale_grad_by_freq,
1239+
const cnnlTensorDescriptor_t indices_desc, const void* indices,
1240+
const cnnlTensorDescriptor_t diff_desc, const void* diff,
1241+
const cnnlTensorDescriptor_t output_desc, void* output);
12301242
};
12311243

12321244
template <typename T>
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/op_registry.h"
16+
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
template <typename T>
22+
class UnStackMLUKernel : public framework::OpKernel<T> {
23+
public:
24+
void Compute(const framework::ExecutionContext &ctx) const override {
25+
auto *x = ctx.Input<Tensor>("X");
26+
auto out = ctx.MultiOutput<Tensor>("Y");
27+
int axis = ctx.Attr<int>("axis");
28+
if (axis < 0) axis += x->dims().size();
29+
int num = x->dims()[axis];
30+
31+
std::vector<MLUCnnlTensorDesc> out_descs;
32+
std::vector<cnnlTensorDescriptor_t> out_raw_descs;
33+
std::vector<void *> out_ptrs;
34+
std::vector<int64_t> new_dims = phi::vectorize(x->dims());
35+
new_dims[axis] = 1;
36+
for (int i = 0; i < num; i++) {
37+
out[i]->mutable_data<T>(ctx.GetPlace());
38+
out_descs.emplace_back(MLUCnnlTensorDesc(new_dims.size(), new_dims.data(),
39+
ToCnnlDataType<T>()));
40+
out_raw_descs.push_back(out_descs.back().get());
41+
out_ptrs.push_back(GetBasePtr(out[i]));
42+
}
43+
44+
MLUCnnlTensorDesc x_desc(*x);
45+
MLUCnnl::Split(ctx, num, axis, x_desc.get(), GetBasePtr(x),
46+
out_raw_descs.data(), out_ptrs.data());
47+
}
48+
};
49+
50+
template <typename T>
51+
class UnStackGradMLUKernel : public framework::OpKernel<T> {
52+
public:
53+
void Compute(const framework::ExecutionContext &ctx) const override {
54+
auto x = ctx.MultiInput<Tensor>(framework::GradVarName("Y"));
55+
auto *y = ctx.Output<Tensor>(framework::GradVarName("X"));
56+
int axis = ctx.Attr<int>("axis");
57+
if (axis < 0) axis += (x[0]->dims().size() + 1);
58+
int num = static_cast<int>(x.size());
59+
60+
std::vector<MLUCnnlTensorDesc> x_descs;
61+
std::vector<cnnlTensorDescriptor_t> x_raw_descs;
62+
std::vector<const void *> x_ptrs;
63+
for (int i = 0; i < num; i++) {
64+
if (x[i]->dims().size() != 0) {
65+
std::vector<int64_t> in_dims = phi::vectorize(x[i]->dims());
66+
in_dims.insert(in_dims.begin() + axis, 1);
67+
x_descs.emplace_back(MLUCnnlTensorDesc(in_dims.size(), in_dims.data(),
68+
ToCnnlDataType<T>()));
69+
} else {
70+
int input_dims = 1;
71+
x_descs.emplace_back(
72+
MLUCnnlTensorDesc(1, &input_dims, ToCnnlDataType<T>()));
73+
}
74+
x_raw_descs.push_back(x_descs.back().get());
75+
x_ptrs.push_back(GetBasePtr(x[i]));
76+
}
77+
y->mutable_data<T>(ctx.GetPlace());
78+
79+
MLUCnnlTensorDesc y_desc(*y);
80+
MLUCnnl::Concat(ctx, num, axis, x_raw_descs.data(), x_ptrs.data(),
81+
y_desc.get(), GetBasePtr(y));
82+
}
83+
};
84+
85+
} // namespace operators
86+
} // namespace paddle
87+
88+
namespace plat = paddle::platform;
89+
namespace ops = paddle::operators;
90+
91+
REGISTER_OP_MLU_KERNEL(unstack, ops::UnStackMLUKernel<float>,
92+
ops::UnStackMLUKernel<plat::float16>);
93+
94+
REGISTER_OP_MLU_KERNEL(unstack_grad, ops::UnStackGradMLUKernel<float>,
95+
ops::UnStackGradMLUKernel<plat::float16>);

0 commit comments

Comments
 (0)