@@ -38,7 +38,7 @@ __global__ void Contiguous2StridedCaseOneFunc(
3838 phi::Array<int64_t , phi::DDim::kMaxRank + 1 > output_stride,
3939 phi::Array<int64_t , 6 > dims,
4040 const int64_t x_max) {
41- int64_t x = blockIdx.x * blockDim.x + threadIdx.x ;
41+ int64_t x = static_cast < int64_t >( blockIdx.x ) * blockDim.x + threadIdx.x ;
4242 if (x < x_max) {
4343 int64_t input_offset = (blockIdx.z * gridDim.y + blockIdx.y ) * x_max + x;
4444 int64_t output_offset = 0 ;
@@ -129,7 +129,7 @@ __global__ void Contiguous2StridedCaseOneDiffDimFunc(
129129 phi::Array<int64_t , phi::DDim::kMaxRank + 1 > output_stride,
130130 phi::Array<int64_t , 6 > dims,
131131 const int64_t x_max) {
132- int64_t x = blockIdx.x * blockDim.x + threadIdx.x ;
132+ int64_t x = static_cast < int64_t >( blockIdx.x ) * blockDim.x + threadIdx.x ;
133133 if (x < x_max) {
134134 int64_t output_offset = 0 ;
135135
@@ -954,8 +954,8 @@ void StrideCopyDiffDimKernel(
954954 const phi::Array<int64_t , phi::DDim::kMaxRank + 1 >& output_stride,
955955 const phi::Array<int64_t , phi::DDim::kMaxRank + 1 >& output_dims,
956956 int rank,
957- int input_numel,
958- int output_numel) {
957+ int64_t input_numel,
958+ int64_t output_numel) {
959959 if (LaunchContiguous2StridedCaseZeroKernel<T, Context>(dev_ctx,
960960 input_data,
961961 output_data,
0 commit comments