Skip to content

Commit c8ff3d5

Browse files
use CastCsrKernel
1 parent 0ee3545 commit c8ff3d5

File tree

5 files changed

+159
-135
lines changed

5 files changed

+159
-135
lines changed

paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -159,27 +159,66 @@ void MatmulCsrCsrGradKernel(const Context& dev_ctx,
159159
perm = {0, 2, 1};
160160
}
161161

162-
// dx{SparseCsr} = dout{Dense} * y'{Dense}
162+
// cusparseSpGEMM only support 32-bit index.
163+
SparseCsrTensor dout_tmp;
164+
CastCsrKernel<T, Context>(
165+
dev_ctx, dout, DataType::INT32, dout.values().dtype(), &dout_tmp);
166+
167+
// dx{SparseCsr} = dout{SparseCsr} * y'{SparseCsr}
163168
if (dx) {
164-
// InferMeta of SparseCsrTensor 'dx', CreateLikeInferMeta
165-
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
166-
// cusparseSPGEMM only support CUSPARSE_OPERATION_NON_TRANSPOSE.
167-
SparseCsrTensor trans_y;
168-
TransposeCsrKernel<T, Context>(dev_ctx, y, perm, &trans_y);
169+
SparseCsrTensor x_tmp, dx_tmp;
170+
CastCsrKernel<T, Context>(
171+
dev_ctx, x, DataType::INT32, x.values().dtype(), &x_tmp);
172+
173+
EmptyLikeCsrKernel<T, Context>(dev_ctx, x_tmp, &dx_tmp);
169174

170-
sparse_blas.SPGEMM(
171-
false, false, static_cast<T>(1), dout, trans_y, static_cast<T>(0), dx);
175+
// cusparseSpGEMM only support CUSPARSE_OPERATION_NON_TRANSPOSE.
176+
SparseCsrTensor trans_y, trans_y_tmp;
177+
TransposeCsrKernel<T, Context>(dev_ctx, y, perm, &trans_y);
178+
CastCsrKernel<T, Context>(dev_ctx,
179+
trans_y,
180+
DataType::INT32,
181+
trans_y.values().dtype(),
182+
&trans_y_tmp);
183+
184+
sparse_blas.SPGEMM(false,
185+
false,
186+
static_cast<T>(1),
187+
dout_tmp,
188+
trans_y_tmp,
189+
static_cast<T>(0),
190+
&dx_tmp);
191+
192+
CastCsrKernel<T, Context>(
193+
dev_ctx, dx_tmp, DataType::INT64, dx_tmp.values().dtype(), dx);
172194
}
173195

174-
// dy{Dense} = x'{SparseCsr} * dout{Dense}
196+
// dy{SparseCsr} = x'{SparseCsr} * dout{SparseCsr}
175197
if (dy) {
176-
// InferMeta of DenseTensor 'dy'
177-
EmptyLikeCsrKernel<T, Context>(dev_ctx, y, dy);
178-
SparseCsrTensor trans_x;
179-
TransposeCsrKernel<T, Context>(dev_ctx, x, perm, &trans_x);
198+
SparseCsrTensor y_tmp, dy_tmp;
199+
CastCsrKernel<T, Context>(
200+
dev_ctx, y, DataType::INT32, y.values().dtype(), &y_tmp);
201+
EmptyLikeCsrKernel<T, Context>(dev_ctx, y_tmp, &dy_tmp);
180202

181-
sparse_blas.SPGEMM(
182-
false, false, static_cast<T>(1), trans_x, dout, static_cast<T>(0), dy);
203+
// cusparseSpGEMM only support CUSPARSE_OPERATION_NON_TRANSPOSE.
204+
SparseCsrTensor trans_x, trans_x_tmp;
205+
TransposeCsrKernel<T, Context>(dev_ctx, x, perm, &trans_x);
206+
CastCsrKernel<T, Context>(dev_ctx,
207+
trans_x,
208+
DataType::INT32,
209+
trans_x.values().dtype(),
210+
&trans_x_tmp);
211+
212+
sparse_blas.SPGEMM(false,
213+
false,
214+
static_cast<T>(1),
215+
trans_x_tmp,
216+
dout_tmp,
217+
static_cast<T>(0),
218+
&dy_tmp);
219+
220+
CastCsrKernel<T, Context>(
221+
dev_ctx, dy_tmp, DataType::INT64, dy_tmp.values().dtype(), dy);
183222
}
184223
#else
185224
#ifdef PADDLE_WITH_CUDA
@@ -197,7 +236,7 @@ void MatmulCooCooGradKernel(const Context& dev_ctx,
197236
const SparseCooTensor& dout,
198237
SparseCooTensor* dx,
199238
SparseCooTensor* dy) {
200-
// 'cusparseSPGEMM' only support CSR now, so use COO->CSR->COO,
239+
// cusparseSpGEMM only support CSR now, so use COO->CSR->COO
201240
SparseCsrTensor x_csr = CooToCsr<T, Context>(dev_ctx, x);
202241
SparseCsrTensor y_csr = CooToCsr<T, Context>(dev_ctx, y);
203242
SparseCsrTensor dout_csr = CooToCsr<T, Context>(dev_ctx, dout);
@@ -288,6 +327,7 @@ PD_REGISTER_KERNEL(matmul_csr_csr_grad,
288327
float,
289328
double) {
290329
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
330+
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR);
291331
}
292332

293333
PD_REGISTER_KERNEL(matmul_coo_coo_grad,
@@ -297,6 +337,7 @@ PD_REGISTER_KERNEL(matmul_coo_coo_grad,
297337
float,
298338
double) {
299339
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
340+
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
300341
}
301342

302343
PD_REGISTER_KERNEL(masked_matmul_csr_grad,

paddle/phi/kernels/sparse/gpu/matmul_kernel.cu

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License. */
2828
#include "paddle/phi/kernels/funcs/math_function_impl.h"
2929
#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
3030
#include "paddle/phi/kernels/sparse/empty_kernel.h"
31+
#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"
3132
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
3233

3334
namespace phi {
@@ -158,21 +159,36 @@ void MatmulCsrCsrKernel(const Context& dev_ctx,
158159
"The shape of Input(x) and Input(y) is not suitable for matmul "
159160
"opetation, x_dim[-1] must be eaqual to y_dim[-2]."));
160161

162+
// cusparseSpGEMM only support 32-bit index.
163+
SparseCsrTensor x_tmp, y_tmp, out_tmp;
164+
CastCsrKernel<T, Context>(
165+
dev_ctx, x, DataType::INT32, x.values().dtype(), &x_tmp);
166+
CastCsrKernel<T, Context>(
167+
dev_ctx, y, DataType::INT32, y.values().dtype(), &y_tmp);
168+
161169
std::vector<int64_t> out_dim_vec = phi::vectorize(out->dims());
162170
int batch_size = 1;
163171
for (int i = 0; i < out_dim_vec.size() - 2; i++) {
164172
batch_size *= out_dim_vec[i];
165173
}
166-
167174
int64_t out_crows_size = batch_size * (xdim_vec[x_ndims - 2] + 1);
168175
DenseTensor out_crows = phi::Empty<int32_t>(dev_ctx, {out_crows_size});
169176
DenseTensor out_cols = phi::Empty<int32_t>(dev_ctx, {0});
170177
DenseTensor out_values = phi::Empty<T>(dev_ctx, {0});
171-
out->SetMember(out_crows, out_cols, out_values, out->dims());
178+
out_tmp.SetMember(out_crows, out_cols, out_values, out->dims());
172179

173180
auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
174-
sparse_blas.SPGEMM(
175-
false, false, static_cast<T>(1), x, y, static_cast<T>(0), out);
181+
sparse_blas.SPGEMM(false,
182+
false,
183+
static_cast<T>(1),
184+
x_tmp,
185+
y_tmp,
186+
static_cast<T>(0),
187+
&out_tmp);
188+
189+
CastCsrKernel<T, Context>(
190+
dev_ctx, out_tmp, DataType::INT64, out_tmp.values().dtype(), out);
191+
176192
#else
177193
#ifdef PADDLE_WITH_CUDA
178194
PADDLE_THROW(phi::errors::Unimplemented(
@@ -307,6 +323,7 @@ PD_REGISTER_KERNEL(matmul_coo_coo,
307323
float,
308324
double) {
309325
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
326+
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
310327
}
311328

312329
PD_REGISTER_KERNEL(matmul_csr_csr,
@@ -316,6 +333,7 @@ PD_REGISTER_KERNEL(matmul_csr_csr,
316333
float,
317334
double) {
318335
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
336+
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR);
319337
}
320338

321339
PD_REGISTER_KERNEL(masked_matmul_csr,

0 commit comments

Comments (0)