Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions paddle/fluid/operators/lookup_table_v2_op_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ class LookupTableV2XPUKernel : public framework::OpKernel<T> {
auto *table_var = context.InputVar("W");
PADDLE_ENFORCE_EQ(
(std::is_same<DeviceContext, platform::XPUDeviceContext>::value), true,
platform::errors::InvalidArgument("Unsupported place!"));
platform::errors::PreconditionNotMet(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是不是提示下这里只支持XPUPlace更清晰

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

"Unsupported place! please check your place"));

PADDLE_ENFORCE_EQ(table_var->IsType<LoDTensor>(), true,
platform::errors::InvalidArgument(
"idx in LookupTableV2XPUKernel should be LoDTensor"));
"Tensor holds the wrong type,idx in "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议报错信息统一首字母大写,结尾加句点

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

中间的逗号后面建议加空格

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

"LookupTableV2XPUKernel should be LoDTensor"));

int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t ids_numel = ids_t->numel();
Expand All @@ -50,14 +52,17 @@ class LookupTableV2XPUKernel : public framework::OpKernel<T> {
const int64_t *ids = ids_t->data<int64_t>();

PADDLE_ENFORCE_EQ(ids_numel <= std::numeric_limits<int32_t>::max(), true,
platform::errors::InvalidArgument(
"idx_numel in LookupTableV2XPUKernel should not "
"greater than int32_t::max."));
platform::errors::OutOfRange(
"idx_numel greater than int32_t::max. please check "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idx还是ids?idx_numel是用户输入的变量吗?如果不是的话,就是我们内部C++变量,建议改成用户更容易接受的描述,比如把变量缩写拆开,比如 number of element

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改,拆成了number of ids

"idx_numel in LookupTableV2XPUKernel"));
int ids_numel_int32 = static_cast<int>(ids_numel);
int r = xpu::embedding<T>(dev_ctx.x_context(), ids_numel_int32, ids, D,
table, output, padding_idx);
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
platform::errors::InvalidArgument("XPU kernel error!"));
platform::errors::External(
"XPU API return wrong value[%d], please check where "
"Baidu Kunlun Card is properly installed.",
r));
}
};

Expand All @@ -67,10 +72,10 @@ class LookupTableV2GradXPUKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &context) const override {
auto *table_var = context.InputVar("W");
DDim table_dim;
PADDLE_ENFORCE_EQ(
table_var->IsType<LoDTensor>(), true,
platform::errors::InvalidArgument(
"idx in LookupTableV2GradXPUKernel should be LoDTensor"));
PADDLE_ENFORCE_EQ(table_var->IsType<LoDTensor>(), true,
platform::errors::InvalidArgument(
"Tensor holds the wrong type,idx in "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里可能不太好改,但是需要斟酌一下,Tensor holds the wrong type,但正确type确实LoDTensor?这不是同级概念,令人困惑

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

"LookupTableV2GradXPUKernel should be LoDTensor"));
table_dim = context.Input<LoDTensor>("W")->dims();

bool is_sparse = context.Attr<bool>("is_sparse");
Expand All @@ -85,9 +90,9 @@ class LookupTableV2GradXPUKernel : public framework::OpKernel<T> {

int64_t ids_numel = ids_t->numel();
PADDLE_ENFORCE_EQ(ids_numel <= std::numeric_limits<int32_t>::max(), true,
platform::errors::InvalidArgument(
"idx_numel in LookupTableV2GradXPUKernel should not "
"greater than int32_t::max."));
platform::errors::OutOfRange(
"idx_numel greater than int32_t::max. please check "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

"idx_numel in LookupTableV2GradXPUKernel"));
int ids_numel_int32 = static_cast<int>(ids_numel);
const int64_t *ids_data = ids_t->data<int64_t>();

Expand All @@ -100,13 +105,19 @@ class LookupTableV2GradXPUKernel : public framework::OpKernel<T> {
int r = xpu::memset(dev_ctx.x_context(), d_table_data, zero,
d_table_t->numel() * sizeof(T));
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
platform::errors::InvalidArgument("XPU kernel error!"));
platform::errors::External(
"XPU API return wrong value[%d], please check where "
"Baidu Kunlun Card is properly installed.",
r));

r = xpu::embedding_backward<T, int64_t>(dev_ctx.x_context(),
ids_numel_int32, ids_data, D,
d_output_data, d_table_data);
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
platform::errors::InvalidArgument("XPU kernel error!"));
platform::errors::External(
"XPU API return wrong value[%d], please check where "
"Baidu Kunlun Card is properly installed.",
r));
}
};
#endif
Expand Down