Skip to content

Commit 1f34f1c

Browse files
fix
1 parent e01251b commit 1f34f1c

File tree

1 file changed

+24
-27
lines changed

1 file changed

+24
-27
lines changed

paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ void SparseBlas<phi::GPUContext>::SDDMM(bool transa,
496496

497497
/************* SPARSE*SPARSE->SPARSE MATMUL ************/
498498
template <typename T>
499-
__global__ void GetCsrBatchNNZ(const int32_t* crow_data,
499+
__global__ void GetCsrBatchNnz(const int32_t* crow_data,
500500
int64_t rows,
501501
int32_t* batch_nnz) {
502502
int64_t i = static_cast<int64_t>(threadIdx.x);
@@ -521,14 +521,6 @@ void SparseBlas<phi::GPUContext>::SPGEMM(bool transa,
521521
out_crows_meta.set_dims(mat_a.crows().dims());
522522
dev_ctx_.template Alloc<int32_t>(mat_out_crows);
523523

524-
MetaTensor out_cols_meta(mat_out_cols);
525-
out_cols_meta.set_dtype(phi::DataType::INT32);
526-
dev_ctx_.template Alloc<int32_t>(mat_out_cols);
527-
528-
MetaTensor out_values_meta(mat_out_values);
529-
out_values_meta.set_dtype(mat_a.values().dtype());
530-
dev_ctx_.template Alloc<T>(mat_out_values);
531-
532524
std::vector<int64_t> a_dim_vec = common::vectorize(mat_a.dims());
533525
auto a_ndims = a_dim_vec.size();
534526
const int64_t a_rows = a_dim_vec[a_ndims - 2];
@@ -544,43 +536,49 @@ void SparseBlas<phi::GPUContext>::SPGEMM(bool transa,
544536
const int64_t b_cols = b_dim_vec[b_ndims - 1];
545537

546538
// cusparseSpGEMM only supports 32-bit indices.
547-
DenseTensor a_crows_int, a_cols_int, b_crows_int, b_cols_int;
548-
const int32_t *a_crows_data, *a_cols_data, *b_crows_data, *b_cols_data;
539+
const int32_t *a_crows_data = nullptr, *a_cols_data = nullptr,
540+
*b_crows_data = nullptr, *b_cols_data = nullptr;
541+
std::shared_ptr<DenseTensor> a_crows_int = nullptr, a_cols_int = nullptr,
542+
b_crows_int = nullptr, b_cols_int = nullptr;
549543

550544
if (mat_a.crows().dtype() == phi::DataType::INT32) {
551545
a_crows_data = mat_a.crows().data<int32_t>();
552546
a_cols_data = mat_a.cols().data<int32_t>();
553547
} else {
554-
phi::MetaTensor crows_meta(&a_crows_int);
548+
a_crows_int = std::make_shared<DenseTensor>();
549+
a_cols_int = std::make_shared<DenseTensor>();
550+
phi::MetaTensor crows_meta(a_crows_int.get());
555551
crows_meta.set_dims(mat_a.crows().dims());
556-
phi::MetaTensor cols_meta(&a_cols_int);
552+
phi::MetaTensor cols_meta(a_cols_int.get());
557553
cols_meta.set_dims(mat_a.cols().dims());
558554

559555
phi::CastKernel<int64_t>(
560-
dev_ctx_, mat_a.crows(), phi::DataType::INT32, &a_crows_int);
556+
dev_ctx_, mat_a.crows(), phi::DataType::INT32, a_crows_int.get());
561557
phi::CastKernel<int64_t>(
562-
dev_ctx_, mat_a.cols(), phi::DataType::INT32, &a_cols_int);
558+
dev_ctx_, mat_a.cols(), phi::DataType::INT32, a_cols_int.get());
563559

564-
a_crows_data = a_crows_int.data<int32_t>();
565-
a_cols_data = a_cols_int.data<int32_t>();
560+
a_crows_data = a_crows_int->data<int32_t>();
561+
a_cols_data = a_cols_int->data<int32_t>();
566562
}
567563

568564
if (mat_b.crows().dtype() == phi::DataType::INT32) {
569565
b_crows_data = mat_b.crows().data<int32_t>();
570566
b_cols_data = mat_b.cols().data<int32_t>();
571567
} else {
572-
phi::MetaTensor crows_meta(&b_crows_int);
568+
b_crows_int = std::make_shared<DenseTensor>();
569+
b_cols_int = std::make_shared<DenseTensor>();
570+
phi::MetaTensor crows_meta(b_crows_int.get());
573571
crows_meta.set_dims(mat_b.crows().dims());
574-
phi::MetaTensor cols_meta(&b_cols_int);
572+
phi::MetaTensor cols_meta(b_cols_int.get());
575573
cols_meta.set_dims(mat_b.cols().dims());
576574

577575
phi::CastKernel<int64_t>(
578-
dev_ctx_, mat_b.crows(), phi::DataType::INT32, &b_crows_int);
576+
dev_ctx_, mat_b.crows(), phi::DataType::INT32, b_crows_int.get());
579577
phi::CastKernel<int64_t>(
580-
dev_ctx_, mat_b.cols(), phi::DataType::INT32, &b_cols_int);
578+
dev_ctx_, mat_b.cols(), phi::DataType::INT32, b_cols_int.get());
581579

582-
b_crows_data = b_crows_int.data<int32_t>();
583-
b_cols_data = b_cols_int.data<int32_t>();
580+
b_crows_data = b_crows_int->data<int32_t>();
581+
b_cols_data = b_cols_int->data<int32_t>();
584582
}
585583

586584
const T* a_values_data = mat_a.values().data<T>();
@@ -601,15 +599,15 @@ void SparseBlas<phi::GPUContext>::SPGEMM(bool transa,
601599
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx_.stream())));
602600
void* tmp_buffer_ptr = tmp_buffer->ptr();
603601

604-
GetCsrBatchNNZ<T><<<1, batch_size, 0, dev_ctx_.stream()>>>(
602+
GetCsrBatchNnz<T><<<1, batch_size, 0, dev_ctx_.stream()>>>(
605603
a_crows_data, a_rows, static_cast<int32_t*>(tmp_buffer_ptr));
606604
phi::backends::gpu::GpuMemcpyAsync(a_batch_nnz_vec.data(),
607605
tmp_buffer_ptr,
608606
batch_size * sizeof(int32_t),
609607
gpuMemcpyDeviceToHost,
610608
dev_ctx_.stream());
611609

612-
GetCsrBatchNNZ<T><<<1, batch_size, 0, dev_ctx_.stream()>>>(
610+
GetCsrBatchNnz<T><<<1, batch_size, 0, dev_ctx_.stream()>>>(
613611
b_crows_data, b_rows, static_cast<int32_t*>(tmp_buffer_ptr));
614612
phi::backends::gpu::GpuMemcpyAsync(b_batch_nnz_vec.data(),
615613
tmp_buffer_ptr,
@@ -815,8 +813,7 @@ void SparseBlas<phi::GPUContext>::SPGEMM(bool transa,
815813
*(mat_out->mutable_values()) = std::move(out_batch_values_vec[0]);
816814

817815
} else {
818-
std::vector<const DenseTensor*> cols_vec;
819-
std::vector<const DenseTensor*> values_vec;
816+
std::vector<const DenseTensor*> cols_vec, values_vec;
820817

821818
for (int i = 0; i < batch_size; ++i) {
822819
cols_vec.push_back(&out_batch_cols_vec[i]);

0 commit comments

Comments
 (0)