@@ -496,7 +496,7 @@ void SparseBlas<phi::GPUContext>::SDDMM(bool transa,
496496
497497/* ************ SPARSE*SPARSE->SPARSE MATMUL ************/
498498template <typename T>
499- __global__ void GetCsrBatchNNZ (const int32_t * crow_data,
499+ __global__ void GetCsrBatchNnz (const int32_t * crow_data,
500500 int64_t rows,
501501 int32_t * batch_nnz) {
502502 int64_t i = static_cast <int64_t >(threadIdx.x );
@@ -521,14 +521,6 @@ void SparseBlas<phi::GPUContext>::SPGEMM(bool transa,
521521 out_crows_meta.set_dims (mat_a.crows ().dims ());
522522 dev_ctx_.template Alloc <int32_t >(mat_out_crows);
523523
524- MetaTensor out_cols_meta (mat_out_cols);
525- out_cols_meta.set_dtype (phi::DataType::INT32);
526- dev_ctx_.template Alloc <int32_t >(mat_out_cols);
527-
528- MetaTensor out_values_meta (mat_out_values);
529- out_values_meta.set_dtype (mat_a.values ().dtype ());
530- dev_ctx_.template Alloc <T>(mat_out_values);
531-
532524 std::vector<int64_t > a_dim_vec = common::vectorize (mat_a.dims ());
533525 auto a_ndims = a_dim_vec.size ();
534526 const int64_t a_rows = a_dim_vec[a_ndims - 2 ];
@@ -544,43 +536,49 @@ void SparseBlas<phi::GPUContext>::SPGEMM(bool transa,
544536 const int64_t b_cols = b_dim_vec[b_ndims - 1 ];
545537
546538 // cusparseSpGEMM only support 32-bit indices.
547- DenseTensor a_crows_int, a_cols_int, b_crows_int, b_cols_int;
548- const int32_t *a_crows_data, *a_cols_data, *b_crows_data, *b_cols_data;
539+ const int32_t *a_crows_data = nullptr , *a_cols_data = nullptr ,
540+ *b_crows_data = nullptr , *b_cols_data = nullptr ;
541+ std::shared_ptr<DenseTensor> a_crows_int = nullptr , a_cols_int = nullptr ,
542+ b_crows_int = nullptr , b_cols_int = nullptr ;
549543
550544 if (mat_a.crows ().dtype () == phi::DataType::INT32) {
551545 a_crows_data = mat_a.crows ().data <int32_t >();
552546 a_cols_data = mat_a.cols ().data <int32_t >();
553547 } else {
554- phi::MetaTensor crows_meta (&a_crows_int);
548+ a_crows_int = std::make_shared<DenseTensor>();
549+ a_cols_int = std::make_shared<DenseTensor>();
550+ phi::MetaTensor crows_meta (a_crows_int.get ());
555551 crows_meta.set_dims (mat_a.crows ().dims ());
556- phi::MetaTensor cols_meta (& a_cols_int);
552+ phi::MetaTensor cols_meta (a_cols_int. get () );
557553 cols_meta.set_dims (mat_a.cols ().dims ());
558554
559555 phi::CastKernel<int64_t >(
560- dev_ctx_, mat_a.crows (), phi::DataType::INT32, & a_crows_int);
556+ dev_ctx_, mat_a.crows (), phi::DataType::INT32, a_crows_int. get () );
561557 phi::CastKernel<int64_t >(
562- dev_ctx_, mat_a.cols (), phi::DataType::INT32, & a_cols_int);
558+ dev_ctx_, mat_a.cols (), phi::DataType::INT32, a_cols_int. get () );
563559
564- a_crows_data = a_crows_int. data <int32_t >();
565- a_cols_data = a_cols_int. data <int32_t >();
560+ a_crows_data = a_crows_int-> data <int32_t >();
561+ a_cols_data = a_cols_int-> data <int32_t >();
566562 }
567563
568564 if (mat_b.crows ().dtype () == phi::DataType::INT32) {
569565 b_crows_data = mat_b.crows ().data <int32_t >();
570566 b_cols_data = mat_b.cols ().data <int32_t >();
571567 } else {
572- phi::MetaTensor crows_meta (&b_crows_int);
568+ b_crows_int = std::make_shared<DenseTensor>();
569+ b_cols_int = std::make_shared<DenseTensor>();
570+ phi::MetaTensor crows_meta (b_crows_int.get ());
573571 crows_meta.set_dims (mat_b.crows ().dims ());
574- phi::MetaTensor cols_meta (& b_cols_int);
572+ phi::MetaTensor cols_meta (b_cols_int. get () );
575573 cols_meta.set_dims (mat_b.cols ().dims ());
576574
577575 phi::CastKernel<int64_t >(
578- dev_ctx_, mat_b.crows (), phi::DataType::INT32, & b_crows_int);
576+ dev_ctx_, mat_b.crows (), phi::DataType::INT32, b_crows_int. get () );
579577 phi::CastKernel<int64_t >(
580- dev_ctx_, mat_b.cols (), phi::DataType::INT32, & b_cols_int);
578+ dev_ctx_, mat_b.cols (), phi::DataType::INT32, b_cols_int. get () );
581579
582- b_crows_data = b_crows_int. data <int32_t >();
583- b_cols_data = b_cols_int. data <int32_t >();
580+ b_crows_data = b_crows_int-> data <int32_t >();
581+ b_cols_data = b_cols_int-> data <int32_t >();
584582 }
585583
586584 const T* a_values_data = mat_a.values ().data <T>();
@@ -601,15 +599,15 @@ void SparseBlas<phi::GPUContext>::SPGEMM(bool transa,
601599 phi::Stream (reinterpret_cast <phi::StreamId>(dev_ctx_.stream ())));
602600 void * tmp_buffer_ptr = tmp_buffer->ptr ();
603601
604- GetCsrBatchNNZ <T><<<1 , batch_size, 0 , dev_ctx_.stream ()>>>(
602+ GetCsrBatchNnz <T><<<1 , batch_size, 0 , dev_ctx_.stream ()>>>(
605603 a_crows_data, a_rows, static_cast <int32_t *>(tmp_buffer_ptr));
606604 phi::backends::gpu::GpuMemcpyAsync (a_batch_nnz_vec.data (),
607605 tmp_buffer_ptr,
608606 batch_size * sizeof (int32_t ),
609607 gpuMemcpyDeviceToHost,
610608 dev_ctx_.stream ());
611609
612- GetCsrBatchNNZ <T><<<1 , batch_size, 0 , dev_ctx_.stream ()>>>(
610+ GetCsrBatchNnz <T><<<1 , batch_size, 0 , dev_ctx_.stream ()>>>(
613611 b_crows_data, b_rows, static_cast <int32_t *>(tmp_buffer_ptr));
614612 phi::backends::gpu::GpuMemcpyAsync (b_batch_nnz_vec.data (),
615613 tmp_buffer_ptr,
@@ -815,8 +813,7 @@ void SparseBlas<phi::GPUContext>::SPGEMM(bool transa,
815813 *(mat_out->mutable_values ()) = std::move (out_batch_values_vec[0 ]);
816814
817815 } else {
818- std::vector<const DenseTensor*> cols_vec;
819- std::vector<const DenseTensor*> values_vec;
816+ std::vector<const DenseTensor*> cols_vec, values_vec;
820817
821818 for (int i = 0 ; i < batch_size; ++i) {
822819 cols_vec.push_back (&out_batch_cols_vec[i]);
0 commit comments