@@ -159,27 +159,66 @@ void MatmulCsrCsrGradKernel(const Context& dev_ctx,
     perm = {0, 2, 1};
   }
 
-  // dx{SparseCsr} = dout{Dense} * y'{Dense}
+  // cusparseSpGEMM only supports 32-bit indices.
+  SparseCsrTensor dout_tmp;
+  CastCsrKernel<T, Context>(
+      dev_ctx, dout, DataType::INT32, dout.values().dtype(), &dout_tmp);
+
+  // dx{SparseCsr} = dout{SparseCsr} * y'{SparseCsr}
   if (dx) {
-    // InferMeta of SparseCsrTensor 'dx', CreateLikeInferMeta
-    EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
-    // cusparseSPGEMM only support CUSPARSE_OPERATION_NON_TRANSPOSE.
-    SparseCsrTensor trans_y;
-    TransposeCsrKernel<T, Context>(dev_ctx, y, perm, &trans_y);
+    SparseCsrTensor x_tmp, dx_tmp;
+    CastCsrKernel<T, Context>(
+        dev_ctx, x, DataType::INT32, x.values().dtype(), &x_tmp);
+
+    EmptyLikeCsrKernel<T, Context>(dev_ctx, x_tmp, &dx_tmp);
 
-    sparse_blas.SPGEMM(
-        false, false, static_cast<T>(1), dout, trans_y, static_cast<T>(0), dx);
+    // cusparseSpGEMM only supports CUSPARSE_OPERATION_NON_TRANSPOSE.
+    SparseCsrTensor trans_y, trans_y_tmp;
+    TransposeCsrKernel<T, Context>(dev_ctx, y, perm, &trans_y);
+    CastCsrKernel<T, Context>(dev_ctx,
+                              trans_y,
+                              DataType::INT32,
+                              trans_y.values().dtype(),
+                              &trans_y_tmp);
+
+    sparse_blas.SPGEMM(false,
+                       false,
+                       static_cast<T>(1),
+                       dout_tmp,
+                       trans_y_tmp,
+                       static_cast<T>(0),
+                       &dx_tmp);
+
+    CastCsrKernel<T, Context>(
+        dev_ctx, dx_tmp, DataType::INT64, dx_tmp.values().dtype(), dx);
   }
 
-  // dy{Dense} = x'{SparseCsr} * dout{Dense}
+  // dy{SparseCsr} = x'{SparseCsr} * dout{SparseCsr}
   if (dy) {
-    // InferMeta of DenseTensor 'dy'
-    EmptyLikeCsrKernel<T, Context>(dev_ctx, y, dy);
-    SparseCsrTensor trans_x;
-    TransposeCsrKernel<T, Context>(dev_ctx, x, perm, &trans_x);
+    SparseCsrTensor y_tmp, dy_tmp;
+    CastCsrKernel<T, Context>(
+        dev_ctx, y, DataType::INT32, y.values().dtype(), &y_tmp);
+    EmptyLikeCsrKernel<T, Context>(dev_ctx, y_tmp, &dy_tmp);
 
-    sparse_blas.SPGEMM(
-        false, false, static_cast<T>(1), trans_x, dout, static_cast<T>(0), dy);
+    // cusparseSpGEMM only supports CUSPARSE_OPERATION_NON_TRANSPOSE.
+    SparseCsrTensor trans_x, trans_x_tmp;
+    TransposeCsrKernel<T, Context>(dev_ctx, x, perm, &trans_x);
+    CastCsrKernel<T, Context>(dev_ctx,
+                              trans_x,
+                              DataType::INT32,
+                              trans_x.values().dtype(),
+                              &trans_x_tmp);
+
+    sparse_blas.SPGEMM(false,
+                       false,
+                       static_cast<T>(1),
+                       trans_x_tmp,
+                       dout_tmp,
+                       static_cast<T>(0),
+                       &dy_tmp);
+
+    CastCsrKernel<T, Context>(
+        dev_ctx, dy_tmp, DataType::INT64, dy_tmp.values().dtype(), dy);
   }
 #else
 #ifdef PADDLE_WITH_CUDA
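A note on the cast round-trip introduced above: cusparseSpGEMM only accepts 32-bit row/column indices, while SparseCsrTensor stores 64-bit indices, so the kernel narrows every operand's indices before the product and widens the result back to INT64 afterwards. A minimal standalone sketch of what the narrowing step amounts to (NarrowCsrIndices is a hypothetical helper for illustration, not part of the Paddle API):

    #include <cstdint>
    #include <limits>
    #include <stdexcept>
    #include <vector>

    // Narrow 64-bit CSR index arrays (crows/cols) to the 32-bit indices
    // cusparseSpGEMM requires, failing loudly instead of overflowing.
    std::vector<int32_t> NarrowCsrIndices(const std::vector<int64_t>& idx64) {
      std::vector<int32_t> idx32;
      idx32.reserve(idx64.size());
      for (int64_t v : idx64) {
        if (v < 0 || v > std::numeric_limits<int32_t>::max()) {
          throw std::overflow_error("CSR index does not fit in 32 bits");
        }
        idx32.push_back(static_cast<int32_t>(v));
      }
      return idx32;
    }

The cast back to INT64 after SPGEMM keeps the kernel's outputs in the index width the rest of the framework expects from a SparseCsrTensor.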
@@ -197,7 +236,7 @@ void MatmulCooCooGradKernel(const Context& dev_ctx,
                              const SparseCooTensor& dout,
                              SparseCooTensor* dx,
                              SparseCooTensor* dy) {
-  // 'cusparseSPGEMM' only support CSR now, so use COO->CSR->COO,
+  // cusparseSpGEMM only supports CSR now, so use COO->CSR->COO.
   SparseCsrTensor x_csr = CooToCsr<T, Context>(dev_ctx, x);
   SparseCsrTensor y_csr = CooToCsr<T, Context>(dev_ctx, y);
   SparseCsrTensor dout_csr = CooToCsr<T, Context>(dev_ctx, dout);
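The COO gradient kernel above round-trips through CSR because cusparseSpGEMM has no COO path: x, y, and dout are converted with CooToCsr, the CSR logic runs, and the results are converted back. A standalone sketch of the row-pointer construction at the heart of a COO->CSR conversion (assuming entries are ordered by row; CooRowsToCsrCrows is a hypothetical name, not Paddle's CooToCsr):

    #include <cstdint>
    #include <vector>

    // Build the CSR row-offset (crows) array, length num_rows + 1, from
    // COO row indices: count nonzeros per row, then prefix-sum the counts.
    std::vector<int64_t> CooRowsToCsrCrows(const std::vector<int64_t>& coo_rows,
                                           int64_t num_rows) {
      std::vector<int64_t> crows(num_rows + 1, 0);
      for (int64_t r : coo_rows) {
        ++crows[r + 1];  // histogram of nonzeros per row
      }
      for (int64_t i = 0; i < num_rows; ++i) {
        crows[i + 1] += crows[i];  // prefix sum turns counts into offsets
      }
      return crows;
    }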
@@ -288,6 +327,7 @@ PD_REGISTER_KERNEL(matmul_csr_csr_grad,
                    float,
                    double) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
+  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR);
 }
 
 PD_REGISTER_KERNEL(matmul_coo_coo_grad,
@@ -297,6 +337,7 @@ PD_REGISTER_KERNEL(matmul_coo_coo_grad,
                    float,
                    double) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
 
 PD_REGISTER_KERNEL(masked_matmul_csr_grad,