paddle/fluid/operators/math/concat_and_split.cu (48 changes: 24 additions & 24 deletions)
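This patch widens the row/column extents in the CUDA split path from int to int64_t, so a tensor whose element count exceeds INT_MAX no longer overflows 32-bit index arithmetic. A minimal standalone sketch of the failure mode the widening guards against (not part of the patch; the shape is hypothetical):

#include <climits>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical tensor: 2 x 1024 x 1024 x 1024 elements.
  int64_t numel = 2LL * 1024 * 1024 * 1024;   // 2,147,483,648
  std::printf("numel   = %lld\n", static_cast<long long>(numel));
  std::printf("INT_MAX = %d\n", INT_MAX);     // 2,147,483,647
  // Narrowing numel (or any row * col product of this size) to int wraps:
  std::printf("as int  = %d\n", static_cast<int>(numel));  // negative on
  return 0;                                   // typical two's-complement targets
}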
@@ -114,8 +114,8 @@ __global__ void ConcatKernel(const T** inputs_data, const int in_num,
 }
 
 template <typename T>
-__global__ void SplitKernel(const T* input_data, const int in_row,
-                            const int in_col, const int* out_cols,
+__global__ void SplitKernel(const T* input_data, const int64_t in_row,
+                            const int64_t in_col, const int64_t* out_cols,
                             int out_cols_size, T** outputs_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
   int curr_segment = 0;
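For context: inside these kernels the flat element offset is computed as row * in_col + col (row-major), so with 32-bit parameters the product wraps once the tensor exceeds 2^31 - 1 elements, even when each extent individually fits in an int. A hedged sketch of the indexing pattern only, not the actual kernel body (name and shape hypothetical):

template <typename T>
__global__ void CopyDemo(const T* input, int64_t in_row, int64_t in_col,
                         T* output) {
  // Grid-stride loops over columns (x) and rows (y); all offset math is
  // done in 64 bits so row * in_col + col cannot wrap for large tensors.
  for (int64_t col = blockIdx.x * blockDim.x + threadIdx.x; col < in_col;
       col += blockDim.x * gridDim.x) {
    for (int64_t row = blockIdx.y; row < in_row; row += gridDim.y) {
      output[row * in_col + col] = input[row * in_col + col];
    }
  }
}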
@@ -159,15 +159,15 @@ __device__ void SplitKernelDetail(const T* input_data, const int in_row,
 }
 
 template <typename T>
-__global__ void SplitKernel(const T* input_data, const int in_row,
-                            const int in_col, const int fixed_out_col,
+__global__ void SplitKernel(const T* input_data, const int64_t in_row,
+                            const int64_t in_col, const int64_t fixed_out_col,
                             T** outputs_data) {
   SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
 }
 
 template <typename T>
-__global__ void SplitKernel(const T* input_data, const int in_row,
-                            const int in_col, const int fixed_out_col,
+__global__ void SplitKernel(const T* input_data, const int64_t in_row,
+                            const int64_t in_col, const int64_t fixed_out_col,
                             T* outputs_addr0, T* outputs_addr1) {
   T* outputs_data[2];
   outputs_data[0] = outputs_addr0;
@@ -176,8 +176,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
 }
 
 template <typename T>
-__global__ void SplitKernel(const T* input_data, const int in_row,
-                            const int in_col, const int fixed_out_col,
+__global__ void SplitKernel(const T* input_data, const int64_t in_row,
+                            const int64_t in_col, const int64_t fixed_out_col,
                             T* outputs_addr0, T* outputs_addr1,
                             T* outputs_addr2) {
   T* outputs_data[3];
@@ -188,8 +188,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
 }
 
 template <typename T>
-__global__ void SplitKernel(const T* input_data, const int in_row,
-                            const int in_col, const int fixed_out_col,
+__global__ void SplitKernel(const T* input_data, const int64_t in_row,
+                            const int64_t in_col, const int64_t fixed_out_col,
                             T* outputs_addr0, T* outputs_addr1,
                             T* outputs_addr2, T* outputs_addr3) {
   T* outputs_data[4];
@@ -201,8 +201,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
 }
 
 static inline void GetBlockDims(const platform::CUDADeviceContext& context,
-                                int num_rows, int num_cols, dim3* block_dims,
-                                dim3* grid_dims) {
+                                int64_t num_rows, int64_t num_cols,
+                                dim3* block_dims, dim3* grid_dims) {
   // Set the thread block and grid according to CurrentDeviceId
   const int kThreadsPerBlock = 1024;
   int block_cols = kThreadsPerBlock;
@@ -213,12 +213,12 @@ static inline void GetBlockDims(const platform::CUDADeviceContext& context,
   *block_dims = dim3(block_cols, block_rows, 1);
 
   int max_threads = context.GetMaxPhysicalThreadCount();
-  int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
+  int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
 
   int grid_cols =
       std::min((num_cols + block_cols - 1) / block_cols, max_blocks);
-  int grid_rows =
-      std::min(max_blocks / grid_cols, std::max(num_rows / block_rows, 1));
+  int grid_rows = std::min(max_blocks / grid_cols,
+                           std::max(num_rows / block_rows, (int64_t)1));
   *grid_dims = dim3(grid_cols, grid_rows, 1);
 }
 
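One subtlety in the GetBlockDims hunk above: std::max and std::min deduce a single template type from both arguments, so once num_rows is int64_t the old std::max(num_rows / block_rows, 1) mixes int64_t with the int literal 1 and no longer compiles; hence the (int64_t)1 cast. A minimal sketch (function name hypothetical):

#include <algorithm>
#include <cstdint>

int64_t ClampRows(int64_t num_rows, int block_rows) {
  // std::max(num_rows / block_rows, 1);  // error: T deduced as both
  //                                      // int64_t and int
  return std::max(num_rows / block_rows, (int64_t)1);  // both int64_t: OK
}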
@@ -319,22 +319,22 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
                   int axis, std::vector<framework::Tensor*>* outputs) {
     // TODO(zcd): Add input data validity checking
     int o_num = outputs->size();
-    int out_row = 1;
+    int64_t out_row = 1;
     auto dim_0 = ref_inputs[0]->dims();
     for (int i = 0; i < axis; ++i) {
       out_row *= dim_0[i];
     }
 
-    int out0_col = ref_inputs[0]->numel() / out_row;
-    int in_col = 0, in_row = out_row;
+    int64_t out0_col = ref_inputs[0]->numel() / out_row;
+    int64_t in_col = 0, in_row = out_row;
     bool has_same_shape = true;
 
     std::vector<T*> outputs_data(o_num);
-    std::vector<int> outputs_cols(o_num + 1);
+    std::vector<int64_t> outputs_cols(o_num + 1);
 
     outputs_cols[0] = 0;
     for (int i = 0; i < o_num; ++i) {
-      int t_col = ref_inputs.at(i)->numel() / out_row;
+      int64_t t_col = ref_inputs.at(i)->numel() / out_row;
       if (has_same_shape) {
         if (t_col != out0_col) has_same_shape = false;
       }
@@ -384,13 +384,13 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
       auto tmp_dev_ins_col_data =
           memory::Alloc(context,
 
-                        outputs_cols.size() * sizeof(int));
+                        outputs_cols.size() * sizeof(int64_t));
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                    tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
                    reinterpret_cast<void*>(outputs_cols.data()),
-                   outputs_cols.size() * sizeof(int), context.stream());
-      int* dev_outs_col_data =
-          reinterpret_cast<int*>(tmp_dev_ins_col_data->ptr());
+                   outputs_cols.size() * sizeof(int64_t), context.stream());
+      int64_t* dev_outs_col_data =
+          reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
 
       SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
           input.data<T>(), in_row, in_col, dev_outs_col_data,
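The sizeof changes in this last hunk keep the host-to-device copy consistent with the widened offsets: copying outputs_cols.size() * sizeof(int) bytes out of a std::vector<int64_t> would transfer only half of the offset bytes, and reading the buffer through an int* would misinterpret the rest. A standalone sketch of the same copy using the plain CUDA runtime API rather than Paddle's memory::Alloc/memory::Copy (error handling omitted; values hypothetical):

#include <cstdint>
#include <vector>
#include <cuda_runtime.h>

int main() {
  std::vector<int64_t> outputs_cols = {0, 128, 256, 384};
  int64_t* dev_outs_col_data = nullptr;
  size_t bytes = outputs_cols.size() * sizeof(int64_t);  // not sizeof(int)
  cudaMalloc(&dev_outs_col_data, bytes);
  cudaMemcpy(dev_outs_col_data, outputs_cols.data(), bytes,
             cudaMemcpyHostToDevice);
  // ... launch a kernel taking const int64_t* column offsets ...
  cudaFree(dev_outs_col_data);
  return 0;
}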