Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions paddle/phi/kernels/funcs/strided_copy_kernel.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ __global__ void Contiguous2StridedCaseOneFunc(
phi::Array<int64_t, phi::DDim::kMaxRank + 1> output_stride,
phi::Array<int64_t, 6> dims,
const int64_t x_max) {
int64_t x = blockIdx.x * blockDim.x + threadIdx.x;
int64_t x = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (x < x_max) {
int64_t input_offset = (blockIdx.z * gridDim.y + blockIdx.y) * x_max + x;
int64_t output_offset = 0;
Expand Down Expand Up @@ -129,7 +129,7 @@ __global__ void Contiguous2StridedCaseOneDiffDimFunc(
phi::Array<int64_t, phi::DDim::kMaxRank + 1> output_stride,
phi::Array<int64_t, 6> dims,
const int64_t x_max) {
int64_t x = blockIdx.x * blockDim.x + threadIdx.x;
int64_t x = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (x < x_max) {
int64_t output_offset = 0;

Expand Down Expand Up @@ -954,8 +954,8 @@ void StrideCopyDiffDimKernel(
const phi::Array<int64_t, phi::DDim::kMaxRank + 1>& output_stride,
const phi::Array<int64_t, phi::DDim::kMaxRank + 1>& output_dims,
int rank,
int input_numel,
int output_numel) {
int64_t input_numel,
int64_t output_numel) {
if (LaunchContiguous2StridedCaseZeroKernel<T, Context>(dev_ctx,
input_data,
output_data,
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/kernels/gpu/strided_copy_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ __global__ void StridedCopyCaseOneFunc(
phi::Array<int64_t, phi::DDim::kMaxRank + 1> output_stride,
phi::Array<int64_t, 6> dims,
const int64_t x_max) {
int64_t x = blockIdx.x * blockDim.x + threadIdx.x;
int64_t x = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (x < x_max) {
int64_t input_offset = 0;
int64_t output_offset = 0;
Expand Down Expand Up @@ -720,8 +720,8 @@ void StridedCopyKernel(const Context& dev_ctx,
meta.offset = offset;
out->set_meta(meta);
int rank = out->dims().size();
auto input_numel = input.numel();
auto output_numel = out->numel();
int64_t input_numel = input.numel();
int64_t output_numel = out->numel();
T* output_data = out->data<T>();
PADDLE_ENFORCE_NOT_NULL(output_data,
common::errors::InvalidArgument(
Expand Down