@@ -720,15 +720,6 @@ class IdxAndOffsetHelper {
720720 index_helper = IdxHelper<N, T>(dims);
721721 }
722722
723- template <typename U>
724- explicit IdxAndOffsetHelper (const U* dims) {
725- T temp_dims[N];
726- for (int i = 0 ; i < N; ++i) {
727- temp_dims[i] = static_cast <T>(dims[i]);
728- }
729- index_helper = IdxHelper<N, T>(temp_dims);
730- }
731-
732723 __device__ inline T IndexToOffset (const T* index) const {
733724 T offset = 0 ;
734725#pragma unroll
@@ -756,13 +747,15 @@ struct PermuteParams {
756747
757748 explicit PermuteParams (const std::vector<int >& dims,
758749 const std::vector<int >& perm_) {
759- size_t dst_dims[Rank];
760- for (size_t i = 0 ; i < Rank; ++i) {
750+ IndexT dst_dims[Rank];
751+ IndexT src_dims[Rank];
752+ for (auto i = 0 ; i < Rank; ++i) {
753+ src_dims[i] = dims[i];
761754 dst_dims[i] = dims[perm_[i]];
762755 perm[i] = perm_[i];
763756 }
764757 dst_index_helper = IdxAndOffsetHelper<IndexT, Rank>(dst_dims);
765- src_index_helper = IdxAndOffsetHelper<IndexT, Rank>(dims. data () );
758+ src_index_helper = IdxAndOffsetHelper<IndexT, Rank>(src_dims );
766759 }
767760};
768761
@@ -915,10 +908,9 @@ template <typename T,
915908 typename IndexT,
916909 int ReadSize,
917910 bool IsVecWrite,
918- int WritSize = IsVecWrite ? (sizeof (T) < sizeof (float )
919- ? sizeof (float ) / sizeof (T)
920- : 1 )
921- : 1 >
911+ int WritSize = (IsVecWrite && (sizeof (T) < sizeof (float )))
912+ ? sizeof (float ) / sizeof (T)
913+ : 1 >
922914__global__ void BatchTransposeKernel (const T* __restrict__ src_data,
923915 T* dst_data,
924916 IndexT rows,
@@ -1000,22 +992,20 @@ inline void LaunchTransposeKernel(const phi::GPUContext& ctx,
1000992 const int rank = dims.size ();
1001993 IndexT num_batch = (rank == 2 ) ? 1 : dims[0 ];
1002994 IndexT rows = dims[rank - 2 ];
995+ IndexT cols = dims[rank - 1 ] / VecSize;
996+ IndexT num_tile_cols = GETTILESIZE (cols, kTileSize );
1003997
1004998 int write_size = 1 ;
1005999 bool is_write_size = sizeof (T) < sizeof (float )
10061000 ? (rows % (sizeof (float ) / sizeof (T)) ? false : true )
10071001 : false ;
10081002 if (is_write_size) {
1009- is_write_size = (num_batch * (( rows + kTileSize - 1 ) & ~( kTileSize - 1 )) /
1010- kTileSize ) >= ctx.GetSMCount ();
1003+ is_write_size = (num_batch * num_tile_cols * GETTILESIZE ( rows, kTileSize )) >
1004+ ctx.GetSMCount ();
10111005 write_size = is_write_size ? sizeof (float ) / sizeof (T) : 1 ;
10121006 }
10131007
1014- IndexT cols = dims[rank - 1 ] / VecSize;
1015- IndexT num_tile_cols = (cols + kTileSize - 1 ) / kTileSize ;
1016- IndexT num_tile_rows =
1017- (rows + kTileSize * write_size - 1 ) / (kTileSize * write_size);
1018-
1008+ IndexT num_tile_rows = GETTILESIZE (rows, (kTileSize * write_size));
10191009 dim3 blocks (num_tile_cols, num_tile_rows, num_batch);
10201010 dim3 threads (kTileSize , kBlockRows , 1 );
10211011
@@ -1174,14 +1164,15 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
11741164 phi::vectorize<int >(in.dims ()),
11751165 in.data <T>(),
11761166 out->data <T>());
1177- auto * tuner = phi::autotune::MakeTransposeTuner<T>(TransposeWithSimple <T>);
1167+ auto * tuner = phi::autotune::MakeTransposeTuner<T>(PermuteAndTranspose <T>);
11781168 tuner->AddCallBack (PermuteWithEigen<T>);
1179- tuner->AddCallBack (PermuteAndTranspose <T>);
1169+ tuner->AddCallBack (TransposeWithSimple <T>);
11801170
11811171 size_t key = phi::autotune::TransposeKey (
1182- phi::vectorize (in. dims () ),
1183- perm ,
1172+ simplifier. GetSrcDims ( ),
1173+ simplifier. GetPerm () ,
11841174 paddle::experimental::CppTypeToDataType<T>::Type ());
1175+
11851176 tuner->Run (ctx,
11861177 phi::autotune::AlgorithmType::kTranspose ,
11871178 key,
0 commit comments