Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions paddle/fluid/operators/roll_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,12 @@ __global__ void RollCudaKernel(const T* input, T* output, int64_t N,

#pragma unroll Rank
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里写#pragma unroll就够了吧?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thanks

for (size_t i = 0; i < Rank; i++) {
dim_idx = (idx / strides[i]) % sizes[i];
dim_idx_shift = (dim_idx + shifts[i]) % sizes[i];
output_idx = output_idx + (dim_idx_shift - dim_idx) * strides[i];
dim_idx = (idx / strides[i]) % sizes[i] + shifts[i];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

变量名应符合实际代表的含义,这里应该是原来的dim_idx_shift,且临时变量dim_idx不再需要,应该删除。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不是dim_idx_shift,就是新的dim_idx位置的预估

if (dim_idx >= sizes[i]) {
output_idx += (shifts[i] - sizes[i]) * strides[i];
} else {
output_idx += shifts[i] * strides[i];
}
}
output[output_idx] = input[idx];
}
Expand Down