Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 33 additions & 19 deletions paddle/fluid/operators/lookup_table_v2_op_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@ class LookupTableV2XPUKernel : public framework::OpKernel<T> {
auto *table_var = context.InputVar("W");
PADDLE_ENFORCE_EQ(
(std::is_same<DeviceContext, platform::XPUDeviceContext>::value), true,
platform::errors::InvalidArgument("Unsupported place!"));
platform::errors::PreconditionNotMet("Unsupported place! only support "
"xpu place , please check your "
"place."));

PADDLE_ENFORCE_EQ(table_var->IsType<LoDTensor>(), true,
platform::errors::InvalidArgument(
"idx in LookupTableV2XPUKernel should be LoDTensor"));
platform::errors::PermissionDenied(
"Unsupported Variable Type , idx in "
"LookupTableV2XPUKernel should be LoDTensor."));

int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t ids_numel = ids_t->numel();
Expand All @@ -49,15 +52,19 @@ class LookupTableV2XPUKernel : public framework::OpKernel<T> {
auto *output = output_t->mutable_data<T>(context.GetPlace());
const int64_t *ids = ids_t->data<int64_t>();

PADDLE_ENFORCE_EQ(ids_numel <= std::numeric_limits<int32_t>::max(), true,
platform::errors::InvalidArgument(
"idx_numel in LookupTableV2XPUKernel should not "
"greater than int32_t::max."));
PADDLE_ENFORCE_EQ(
ids_numel <= std::numeric_limits<int32_t>::max(), true,
platform::errors::OutOfRange(
"Number of ids greater than int32_t::max , please check "
"number of ids in LookupTableV2XPUKernel."));
int ids_numel_int32 = static_cast<int>(ids_numel);
int r = xpu::embedding<T>(dev_ctx.x_context(), ids_numel_int32, ids, D,
table, output, padding_idx);
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
platform::errors::InvalidArgument("XPU kernel error!"));
platform::errors::External(
"XPU API return wrong value[%d] , please check where "
"Baidu Kunlun Card is properly installed.",
r));
}
};

Expand All @@ -67,27 +74,28 @@ class LookupTableV2GradXPUKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &context) const override {
auto *table_var = context.InputVar("W");
DDim table_dim;
PADDLE_ENFORCE_EQ(
table_var->IsType<LoDTensor>(), true,
platform::errors::InvalidArgument(
"idx in LookupTableV2GradXPUKernel should be LoDTensor"));
PADDLE_ENFORCE_EQ(table_var->IsType<LoDTensor>(), true,
platform::errors::PermissionDenied(
"Unsupported Variable Type , idx in "
"LookupTableV2GradXPUKernel should be LoDTensor."));
table_dim = context.Input<LoDTensor>("W")->dims();

bool is_sparse = context.Attr<bool>("is_sparse");
PADDLE_ENFORCE_EQ(
is_sparse, false,
platform::errors::InvalidArgument(
"LookupTableV2GradXPUKernel dose NOT support is_sparse = True"));
"LookupTableV2GradXPUKernel dose NOT support is_sparse = True."));

auto ids_t = context.Input<LoDTensor>("Ids");
auto d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));

int64_t ids_numel = ids_t->numel();
PADDLE_ENFORCE_EQ(ids_numel <= std::numeric_limits<int32_t>::max(), true,
platform::errors::InvalidArgument(
"idx_numel in LookupTableV2GradXPUKernel should not "
"greater than int32_t::max."));
PADDLE_ENFORCE_EQ(
ids_numel <= std::numeric_limits<int32_t>::max(), true,
platform::errors::OutOfRange(
"Number of ids greater than int32_t::max , please check "
"number of ids in LookupTableV2GradXPUKernel."));
int ids_numel_int32 = static_cast<int>(ids_numel);
const int64_t *ids_data = ids_t->data<int64_t>();

Expand All @@ -100,13 +108,19 @@ class LookupTableV2GradXPUKernel : public framework::OpKernel<T> {
int r = xpu::memset(dev_ctx.x_context(), d_table_data, zero,
d_table_t->numel() * sizeof(T));
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
platform::errors::InvalidArgument("XPU kernel error!"));
platform::errors::External(
"XPU API return wrong value[%d], please check where "
"Baidu Kunlun Card is properly installed.",
r));

r = xpu::embedding_backward<T, int64_t>(dev_ctx.x_context(),
ids_numel_int32, ids_data, D,
d_output_data, d_table_data);
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
platform::errors::InvalidArgument("XPU kernel error!"));
platform::errors::External(
"XPU API return wrong value[%d] , please check where "
"Baidu Kunlun Card is properly installed.",
r));
}
};
#endif
Expand Down