Skip to content

Commit 44d7755

Browse files
Fix paddle.unfold for big tensor (#74379)
1 parent 8e5cba3 commit 44d7755

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

paddle/phi/kernels/funcs/strided_copy_kernel.cu.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ __global__ void Contiguous2StridedCaseOneFunc(
3838
phi::Array<int64_t, phi::DDim::kMaxRank + 1> output_stride,
3939
phi::Array<int64_t, 6> dims,
4040
const int64_t x_max) {
41-
int64_t x = blockIdx.x * blockDim.x + threadIdx.x;
41+
int64_t x = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
4242
if (x < x_max) {
4343
int64_t input_offset = (blockIdx.z * gridDim.y + blockIdx.y) * x_max + x;
4444
int64_t output_offset = 0;
@@ -129,7 +129,7 @@ __global__ void Contiguous2StridedCaseOneDiffDimFunc(
129129
phi::Array<int64_t, phi::DDim::kMaxRank + 1> output_stride,
130130
phi::Array<int64_t, 6> dims,
131131
const int64_t x_max) {
132-
int64_t x = blockIdx.x * blockDim.x + threadIdx.x;
132+
int64_t x = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
133133
if (x < x_max) {
134134
int64_t output_offset = 0;
135135

@@ -954,8 +954,8 @@ void StrideCopyDiffDimKernel(
954954
const phi::Array<int64_t, phi::DDim::kMaxRank + 1>& output_stride,
955955
const phi::Array<int64_t, phi::DDim::kMaxRank + 1>& output_dims,
956956
int rank,
957-
int input_numel,
958-
int output_numel) {
957+
int64_t input_numel,
958+
int64_t output_numel) {
959959
if (LaunchContiguous2StridedCaseZeroKernel<T, Context>(dev_ctx,
960960
input_data,
961961
output_data,

paddle/phi/kernels/gpu/strided_copy_kernel.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ __global__ void StridedCopyCaseOneFunc(
125125
phi::Array<int64_t, phi::DDim::kMaxRank + 1> output_stride,
126126
phi::Array<int64_t, 6> dims,
127127
const int64_t x_max) {
128-
int64_t x = blockIdx.x * blockDim.x + threadIdx.x;
128+
int64_t x = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
129129
if (x < x_max) {
130130
int64_t input_offset = 0;
131131
int64_t output_offset = 0;
@@ -720,8 +720,8 @@ void StridedCopyKernel(const Context& dev_ctx,
720720
meta.offset = offset;
721721
out->set_meta(meta);
722722
int rank = out->dims().size();
723-
auto input_numel = input.numel();
724-
auto output_numel = out->numel();
723+
int64_t input_numel = input.numel();
724+
int64_t output_numel = out->numel();
725725
T* output_data = out->data<T>();
726726
PADDLE_ENFORCE_NOT_NULL(output_data,
727727
common::errors::InvalidArgument(

0 commit comments

Comments
 (0)