-
Notifications
You must be signed in to change notification settings - Fork 6k
lookup_table_v2_op_xpu report errors;test=kunlun #28064
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,11 +31,13 @@ class LookupTableV2XPUKernel : public framework::OpKernel<T> { | |
| auto *table_var = context.InputVar("W"); | ||
| PADDLE_ENFORCE_EQ( | ||
| (std::is_same<DeviceContext, platform::XPUDeviceContext>::value), true, | ||
| platform::errors::InvalidArgument("Unsupported place!")); | ||
| platform::errors::PreconditionNotMet( | ||
| "Unsupported place! please check your place")); | ||
|
|
||
| PADDLE_ENFORCE_EQ(table_var->IsType<LoDTensor>(), true, | ||
| platform::errors::InvalidArgument( | ||
| "idx in LookupTableV2XPUKernel should be LoDTensor")); | ||
| "Tensor holds the wrong type,idx in " | ||
|
||
| "LookupTableV2XPUKernel should be LoDTensor")); | ||
|
|
||
| int64_t padding_idx = context.Attr<int64_t>("padding_idx"); | ||
| int64_t ids_numel = ids_t->numel(); | ||
|
|
@@ -50,14 +52,17 @@ class LookupTableV2XPUKernel : public framework::OpKernel<T> { | |
| const int64_t *ids = ids_t->data<int64_t>(); | ||
|
|
||
| PADDLE_ENFORCE_EQ(ids_numel <= std::numeric_limits<int32_t>::max(), true, | ||
| platform::errors::InvalidArgument( | ||
| "idx_numel in LookupTableV2XPUKernel should not " | ||
| "greater than int32_t::max.")); | ||
| platform::errors::OutOfRange( | ||
| "idx_numel greater than int32_t::max. please check " | ||
|
||
| "idx_numel in LookupTableV2XPUKernel")); | ||
| int ids_numel_int32 = static_cast<int>(ids_numel); | ||
| int r = xpu::embedding<T>(dev_ctx.x_context(), ids_numel_int32, ids, D, | ||
| table, output, padding_idx); | ||
| PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, | ||
| platform::errors::InvalidArgument("XPU kernel error!")); | ||
| platform::errors::External( | ||
| "XPU API return wrong value[%d], please check where " | ||
| "Baidu Kunlun Card is properly installed.", | ||
| r)); | ||
| } | ||
| }; | ||
|
|
||
|
|
@@ -67,10 +72,10 @@ class LookupTableV2GradXPUKernel : public framework::OpKernel<T> { | |
| void Compute(const framework::ExecutionContext &context) const override { | ||
| auto *table_var = context.InputVar("W"); | ||
| DDim table_dim; | ||
| PADDLE_ENFORCE_EQ( | ||
| table_var->IsType<LoDTensor>(), true, | ||
| platform::errors::InvalidArgument( | ||
| "idx in LookupTableV2GradXPUKernel should be LoDTensor")); | ||
| PADDLE_ENFORCE_EQ(table_var->IsType<LoDTensor>(), true, | ||
| platform::errors::InvalidArgument( | ||
| "Tensor holds the wrong type,idx in " | ||
|
||
| "LookupTableV2GradXPUKernel should be LoDTensor")); | ||
| table_dim = context.Input<LoDTensor>("W")->dims(); | ||
|
|
||
| bool is_sparse = context.Attr<bool>("is_sparse"); | ||
|
|
@@ -85,9 +90,9 @@ class LookupTableV2GradXPUKernel : public framework::OpKernel<T> { | |
|
|
||
| int64_t ids_numel = ids_t->numel(); | ||
| PADDLE_ENFORCE_EQ(ids_numel <= std::numeric_limits<int32_t>::max(), true, | ||
| platform::errors::InvalidArgument( | ||
| "idx_numel in LookupTableV2GradXPUKernel should not " | ||
| "greater than int32_t::max.")); | ||
| platform::errors::OutOfRange( | ||
| "idx_numel greater than int32_t::max. please check " | ||
|
||
| "idx_numel in LookupTableV2GradXPUKernel")); | ||
| int ids_numel_int32 = static_cast<int>(ids_numel); | ||
| const int64_t *ids_data = ids_t->data<int64_t>(); | ||
|
|
||
|
|
@@ -100,13 +105,19 @@ class LookupTableV2GradXPUKernel : public framework::OpKernel<T> { | |
| int r = xpu::memset(dev_ctx.x_context(), d_table_data, zero, | ||
| d_table_t->numel() * sizeof(T)); | ||
| PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, | ||
| platform::errors::InvalidArgument("XPU kernel error!")); | ||
| platform::errors::External( | ||
| "XPU API return wrong value[%d], please check where " | ||
| "Baidu Kunlun Card is properly installed.", | ||
| r)); | ||
|
|
||
| r = xpu::embedding_backward<T, int64_t>(dev_ctx.x_context(), | ||
| ids_numel_int32, ids_data, D, | ||
| d_output_data, d_table_data); | ||
| PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, | ||
| platform::errors::InvalidArgument("XPU kernel error!")); | ||
| platform::errors::External( | ||
| "XPU API return wrong value[%d], please check where " | ||
| "Baidu Kunlun Card is properly installed.", | ||
| r)); | ||
| } | ||
| }; | ||
| #endif | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是不是提示下这里只支持XPUPlace更清晰
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改