Commit 4c2d3cb

add some optimization for general_permute transpose
1 parent 225e1bb commit 4c2d3cb

File tree: 4 files changed, +36 -41 lines changed

paddle/fluid/operators/transpose_op.cu.h

Lines changed: 18 additions & 27 deletions
@@ -720,15 +720,6 @@ class IdxAndOffsetHelper {
     index_helper = IdxHelper<N, T>(dims);
   }
 
-  template <typename U>
-  explicit IdxAndOffsetHelper(const U* dims) {
-    T temp_dims[N];
-    for (int i = 0; i < N; ++i) {
-      temp_dims[i] = static_cast<T>(dims[i]);
-    }
-    index_helper = IdxHelper<N, T>(temp_dims);
-  }
-
   __device__ inline T IndexToOffset(const T* index) const {
     T offset = 0;
 #pragma unroll
@@ -756,13 +747,15 @@ struct PermuteParams {
 
   explicit PermuteParams(const std::vector<int>& dims,
                          const std::vector<int>& perm_) {
-    size_t dst_dims[Rank];
-    for (size_t i = 0; i < Rank; ++i) {
+    IndexT dst_dims[Rank];
+    IndexT src_dims[Rank];
+    for (auto i = 0; i < Rank; ++i) {
+      src_dims[i] = dims[i];
       dst_dims[i] = dims[perm_[i]];
       perm[i] = perm_[i];
     }
     dst_index_helper = IdxAndOffsetHelper<IndexT, Rank>(dst_dims);
-    src_index_helper = IdxAndOffsetHelper<IndexT, Rank>(dims.data());
+    src_index_helper = IdxAndOffsetHelper<IndexT, Rank>(src_dims);
   }
 };
 
@@ -915,10 +908,9 @@ template <typename T,
           typename IndexT,
           int ReadSize,
           bool IsVecWrite,
-          int WritSize = IsVecWrite ? (sizeof(T) < sizeof(float)
-                                           ? sizeof(float) / sizeof(T)
-                                           : 1)
-                                    : 1>
+          int WritSize = (IsVecWrite && (sizeof(T) < sizeof(float)))
+                             ? sizeof(float) / sizeof(T)
+                             : 1>
 __global__ void BatchTransposeKernel(const T* __restrict__ src_data,
                                      T* dst_data,
                                      IndexT rows,
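
The default for WritSize previously nested two ternaries; the new form folds the IsVecWrite flag into a single condition, so the intent reads directly: pack sizeof(float)/sizeof(T) elements per write only when vectorized writes are requested and the element type is narrower than float. A minimal standalone sketch of the same selection logic (not Paddle code), using uint16_t as a stand-in for a 2-byte type such as half:

#include <cstdint>
#include <cstdio>

// Illustration of how the simplified WritSize default resolves.
template <typename T, bool IsVecWrite>
constexpr int ResolveWritSize() {
  return (IsVecWrite && (sizeof(T) < sizeof(float)))
             ? static_cast<int>(sizeof(float) / sizeof(T))
             : 1;
}

int main() {
  printf("uint16_t, vec write: %d\n", ResolveWritSize<uint16_t, true>());   // 2
  printf("uint16_t, no vec   : %d\n", ResolveWritSize<uint16_t, false>());  // 1
  printf("float,    vec write: %d\n", ResolveWritSize<float, true>());      // 1
  printf("double,   vec write: %d\n", ResolveWritSize<double, true>());     // 1
  return 0;
}
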
@@ -1000,22 +992,20 @@ inline void LaunchTransposeKernel(const phi::GPUContext& ctx,
   const int rank = dims.size();
   IndexT num_batch = (rank == 2) ? 1 : dims[0];
   IndexT rows = dims[rank - 2];
+  IndexT cols = dims[rank - 1] / VecSize;
+  IndexT num_tile_cols = GETTILESIZE(cols, kTileSize);
 
   int write_size = 1;
   bool is_write_size = sizeof(T) < sizeof(float)
                            ? (rows % (sizeof(float) / sizeof(T)) ? false : true)
                            : false;
   if (is_write_size) {
-    is_write_size = (num_batch * ((rows + kTileSize - 1) & ~(kTileSize - 1)) /
-                     kTileSize) >= ctx.GetSMCount();
+    is_write_size = (num_batch * num_tile_cols * GETTILESIZE(rows, kTileSize)) >
+                    ctx.GetSMCount();
     write_size = is_write_size ? sizeof(float) / sizeof(T) : 1;
   }
 
-  IndexT cols = dims[rank - 1] / VecSize;
-  IndexT num_tile_cols = (cols + kTileSize - 1) / kTileSize;
-  IndexT num_tile_rows =
-      (rows + kTileSize * write_size - 1) / (kTileSize * write_size);
-
+  IndexT num_tile_rows = GETTILESIZE(rows, (kTileSize * write_size));
   dim3 blocks(num_tile_cols, num_tile_rows, num_batch);
   dim3 threads(kTileSize, kBlockRows, 1);
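
The rewritten gate counts every tile in the grid (batches x column tiles x row tiles) against the SM count before enabling vectorized writes, instead of only counting row tiles, and the column-tile math is hoisted so it is computed once. A minimal standalone sketch of that launch math (not Paddle code), assuming kTileSize = 32, a 2-byte element type (so the write_size candidate is 2) and a hypothetical sm_count of 80:

#include <cstdio>

// GETTILESIZE is the round-up division macro added in transpose_op.h.
#define GETTILESIZE(LEN, ALIGN) ((LEN + (ALIGN - 1)) & ~(ALIGN - 1)) / ALIGN
constexpr int kTileSize = 32;

int main() {
  const int num_batch = 16, rows = 64, cols = 128, sm_count = 80;
  const int num_tile_cols = GETTILESIZE(cols, kTileSize);
  // Vectorized writes only pay off when there are more tiles than SMs.
  const bool is_write_size =
      (num_batch * num_tile_cols * GETTILESIZE(rows, kTileSize)) > sm_count;
  const int write_size = is_write_size ? 2 : 1;
  const int num_tile_rows = GETTILESIZE(rows, (kTileSize * write_size));
  printf("grid = (%d, %d, %d), write_size = %d\n",
         num_tile_cols, num_tile_rows, num_batch, write_size);  // (4, 1, 16), 2
  return 0;
}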

@@ -1174,14 +1164,15 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
         phi::vectorize<int>(in.dims()),
         in.data<T>(),
         out->data<T>());
-    auto* tuner = phi::autotune::MakeTransposeTuner<T>(TransposeWithSimple<T>);
+    auto* tuner = phi::autotune::MakeTransposeTuner<T>(PermuteAndTranspose<T>);
     tuner->AddCallBack(PermuteWithEigen<T>);
-    tuner->AddCallBack(PermuteAndTranspose<T>);
+    tuner->AddCallBack(TransposeWithSimple<T>);
 
     size_t key = phi::autotune::TransposeKey(
-        phi::vectorize(in.dims()),
-        perm,
+        simplifier.GetSrcDims(),
+        simplifier.GetPerm(),
         paddle::experimental::CppTypeToDataType<T>::Type());
+
     tuner->Run(ctx,
                phi::autotune::AlgorithmType::kTranspose,
                key,
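
PermuteAndTranspose becomes the tuner's baseline callback, and the autotune cache key is now built from the simplified dims and perm returned by the DimsSimplifier rather than the raw input shape, presumably so that inputs which simplify to the same shape and permutation share one autotuned choice. A generic sketch of combining dims, perm and a dtype id into such a key; this is not the actual phi::autotune::TransposeKey implementation, and MakeKey is a hypothetical name used only for illustration:

#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Hash-combine the simplified dims, the perm and a dtype identifier into a
// single cache key (illustrative only).
size_t MakeKey(const std::vector<int32_t>& dims,
               const std::vector<int32_t>& perm,
               int dtype_id) {
  size_t seed = std::hash<int>()(dtype_id);
  auto combine = [&seed](int32_t v) {
    seed ^= std::hash<int32_t>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  };
  for (int32_t d : dims) combine(d);
  for (int32_t p : perm) combine(p);
  return seed;
}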

paddle/fluid/operators/transpose_op.h

Lines changed: 16 additions & 12 deletions
@@ -73,6 +73,8 @@ enum PermuteType {
 constexpr int kBlockRows = 16;
 constexpr int kTileSize = 32;
 
+#define GETTILESIZE(LEN, ALIGN) ((LEN + (ALIGN - 1)) & ~(ALIGN - 1)) / ALIGN
+
 // Simplify the input dims and permute dims if possible.
 template <typename T>
 class DimsSimplifier {
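
GETTILESIZE rounds LEN up to the next multiple of ALIGN with a bit mask and then divides, i.e. a ceiling division that requires ALIGN to be a power of two (kTileSize = 32 qualifies). A quick standalone check of that equivalence:

#include <cassert>
#include <cstdio>

// For power-of-two ALIGN, the bit trick equals (LEN + ALIGN - 1) / ALIGN.
#define GETTILESIZE(LEN, ALIGN) ((LEN + (ALIGN - 1)) & ~(ALIGN - 1)) / ALIGN

int main() {
  for (int len = 1; len <= 256; ++len) {
    assert(GETTILESIZE(len, 32) == (len + 31) / 32);
  }
  printf("GETTILESIZE(100, 32) = %d tiles\n", GETTILESIZE(100, 32));  // 4
  return 0;
}

Because the trailing division by ALIGN is not parenthesized inside the macro body, compound arguments must be parenthesized by the caller, which is why transpose_op.cu.h passes (kTileSize * write_size) rather than kTileSize * write_size.
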
@@ -182,38 +184,40 @@ class DimsSimplifier {
   }
 
   int GetPermVecSize(const int sm_count, const T* src, T* dst) {
-    // For gerneal_permute kernel, there is good chance for
-    // vectorized write.
+    // For gerneal_permute kernel, there is chance for vectorized write.
     type_ = PermuteType::kNormalPermute;
     int vec_size = phi::GetVectorizedSize<T>(dst);
 
-    // While the last dim is fixed, there is good chance for
-    // both vectorized read and write.
+    // While the last dim is fixed, there is chance for vectorized IO.
     if (perm_[rank_ - 1] == rank_ - 1) {
       int tmp_size = std::min(vec_size, phi::GetVectorizedSize<T>(src));
-      tmp_size = GetDimVesSize(tmp_size, src_dims[rank_ - 1]);
+      tmp_size = GetDimVecSize(tmp_size, src_dims[rank_ - 1]);
       if (tmp_size > 1) {
         type_ = kVecPermute;
         vec_size = tmp_size;
       }
     }
 
-    // Once only transpose at the last 2 dims, there is good
-    // chance for vectorized read.
+    // Once only transpose at the last 2 dims.
    if ((rank_ == 2 && perm_[1] == 0 && perm_[0] == 1) ||
         (rank_ == 3 && perm_[2] == 1 && perm_[1] == 2)) {
       type_ = PermuteType::kTranspose;
-      // With bytes limitation of shared_memory, the VecSize shall be
-      // restricted for the type whose byte-size is less than 8 (double).
+      // With bytes limitation of shared_memory, the VecSize
+      // shall be restricted to sizeof(float).
       int tmp_vec = std::min(vec_size, phi::GetVectorizedSize<T>(src));
-      vec_size =
-          sizeof(T) > 4 ? 1 : GetDimVesSize(tmp_vec, src_dims[rank_ - 1]);
+      vec_size = sizeof(T) > sizeof(float)
+                     ? 1
+                     : GetDimVecSize(tmp_vec, src_dims[rank_ - 1]);
+      const int tile_size = (rank_ == 2 ? 1 : src_dims[0]) *
+                            GETTILESIZE(src_dims[rank_ - 1], kTileSize) *
+                            GETTILESIZE(src_dims[rank_ - 2], kTileSize);
+      vec_size = tile_size < sm_count ? 1 : vec_size;
     }
     return vec_size;
   }
 
   // To find if highest common divisor and make it as vec_size.
-  int GetDimVesSize(const int vec_size, const size_t target_dim) {
+  int GetDimVecSize(const int vec_size, const size_t target_dim) {
     int dim_vec_size = 1;
     for (auto size = vec_size; size > 0; size /= 2) {
       if (target_dim % size == 0) {
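
GetDimVesSize is renamed to GetDimVecSize, and the transpose branch now also compares the total tile count against sm_count before keeping a vectorized width; small grids fall back to vec_size = 1, presumably to keep enough thread blocks to occupy the SMs. The helper's body is truncated in the diff, so the sketch below (not Paddle code) completes it from the visible loop and the comment above it; the statements inside the if-branch are an assumption:

#include <cstdio>

// Walk down from the candidate width in powers of two until one divides the
// target dimension, returning the largest such divisor.
int GetDimVecSize(const int vec_size, const size_t target_dim) {
  int dim_vec_size = 1;
  for (auto size = vec_size; size > 0; size /= 2) {
    if (target_dim % size == 0) {
      dim_vec_size = size;  // assumed: record the first (largest) divisor
      break;                // assumed: and stop searching
    }
  }
  return dim_vec_size;
}

int main() {
  printf("%d\n", GetDimVecSize(4, 128));  // 4: last dim divisible by 4
  printf("%d\n", GetDimVecSize(4, 6));    // 2: falls back to a narrower width
  printf("%d\n", GetDimVecSize(4, 7));    // 1: odd dim disables vectorization
  return 0;
}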

paddle/phi/kernels/autotune/cache.cc

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 namespace phi {
 namespace autotune {
 
-size_t TransposeKey(const std::vector<int64_t>& x_dims,
+size_t TransposeKey(const std::vector<int32_t>& x_dims,
                     const std::vector<int32_t>& perm,
                     phi::DataType dtype) {
   const auto rank = perm.size();

paddle/phi/kernels/autotune/cache.h

Lines changed: 1 addition & 1 deletion
@@ -196,7 +196,7 @@ class CudnnAlgorithmsCacheMap {
   int64_t cache_misses_{0};
 };
 
-size_t TransposeKey(const std::vector<int64_t>& x_dims,
+size_t TransposeKey(const std::vector<int32_t>& x_dims,
                     const std::vector<int32_t>& perm,
                     phi::DataType dtype);
