diff --git a/paddle/phi/kernels/funcs/transpose_function.cu.h b/paddle/phi/kernels/funcs/transpose_function.cu.h
index 2cc8dd2d361e41..59daa0b8d73c89 100644
--- a/paddle/phi/kernels/funcs/transpose_function.cu.h
+++ b/paddle/phi/kernels/funcs/transpose_function.cu.h
@@ -557,12 +557,160 @@ __global__ void TransposeSimpleKernel(IndexType nthreads,
   }
 }
 
+typedef struct alignas(8) fp8x8_t {
+  union data_t {
+    phi::float8_e4m3fn scalar[8];
+    uint2 vector;
+  };
+  data_t data;
+
+  __device__ __forceinline__ void load(const void* ptr) {
+    data = *reinterpret_cast<const data_t*>(ptr);
+  }
+
+  __device__ __forceinline__ void store(void* ptr) const {
+    *reinterpret_cast<data_t*>(ptr) = data;
+  }
+} fp8x8_t;
+
+constexpr int kVecSize = 8;
+constexpr int BLOCK_DIM = 16;
+constexpr int BLOCK_TILE_SIZE = 128;
+constexpr int BLOCK_TILE_WIDTH = BLOCK_TILE_SIZE;
+constexpr int BLOCK_TILE_HEIGHT = BLOCK_TILE_SIZE;
+constexpr int THREAD_TILE_DIM = BLOCK_TILE_SIZE / BLOCK_DIM;
+
+__global__ void
+__launch_bounds__(BLOCK_DIM* BLOCK_DIM) inline fp8_fast_transpose_kernel(
+    const phi::float8_e4m3fn* __restrict__ src,  // Source matrix (M x N)
+    phi::float8_e4m3fn* __restrict__ dst,        // Destination matrix (N x M)
+    int B,
+    int M,
+    int N,                  // Batch size, M-dimension, N-dimension
+    size_t batch_stride) {  // Stride between batches in global memory (M*N
+                            // elements)
+  // Shared memory tile padded to avoid bank conflicts; padding is used
+  // instead of swizzling for better performance.
+  __shared__ __align__(1024)
+      fp8x8_t smem[BLOCK_TILE_HEIGHT][BLOCK_TILE_WIDTH / kVecSize + 1];
+
+  // Thread-local storage: 8 fp8x8_t units, effectively an 8x8 block of fp8_t
+  // values.
+  fp8x8_t local_tile[kVecSize];
+  fp8x8_t local_tile_transposed[kVecSize];
+
+  // Thread indices within the block (0-15 for x and y, since 16x16 = 256
+  // threads)
+  const uint32_t tid_x = threadIdx.x;  // Column-wise thread index (0-15)
+  const uint32_t tid_y = threadIdx.y;  // Row-wise thread index (0-15)
+
+  // Block indices within the grid
+  const uint32_t block_x = blockIdx.x;  // Tile index along N-dimension
+  const uint32_t block_y = blockIdx.y;  // Tile index along M-dimension
+  const uint32_t block_z = blockIdx.z;  // Batch index
+
+  // Calculate global offsets for the current block's tile in the M x N source
+  // matrix
+  const uint32_t global_m_offset =
+      block_y * BLOCK_TILE_HEIGHT;  // Starting M index for this block
+  const uint32_t global_n_offset =
+      block_x * BLOCK_TILE_WIDTH;  // Starting N index for this block
+
+  const size_t current_batch_offset =
+      static_cast<size_t>(batch_stride) * block_z;
+
+// 1. Load src into registers in uint2 vectorized manner.
+#pragma unroll
+  for (uint32_t k = 0; k < THREAD_TILE_DIM;
+       ++k) {  // Iterate 8 times for the 8 rows in the thread's block
+    const uint32_t src_global_row =
+        global_m_offset + tid_y * THREAD_TILE_DIM + k;
+    const uint32_t src_global_col_start =
+        global_n_offset + tid_x * THREAD_TILE_DIM;
+
+    // No bounds check needed: the dispatcher guarantees tile-aligned M and N.
+    // THREAD_TILE_DIM (8) is the width of the fp8x8_t block.
+    const phi::float8_e4m3fn* src_ptr =
+        src + current_batch_offset + static_cast<size_t>(src_global_row) * N +
+        src_global_col_start;
+    local_tile[k].load(src_ptr);
+  }
+
+// 2. Transpose local_tile at register level.
+#pragma unroll
+  for (uint32_t k_row = 0; k_row < THREAD_TILE_DIM; ++k_row) {
+#pragma unroll
+    for (uint32_t k_col = 0; k_col < THREAD_TILE_DIM; ++k_col) {
+      local_tile_transposed[k_col].data.scalar[k_row] =
+          local_tile[k_row].data.scalar[k_col];
+    }
+  }
+
+// 3. Store transposed data to shared memory.
+#pragma unroll
+  for (uint32_t k = 0; k < THREAD_TILE_DIM; ++k) {
+    const uint32_t smem_row = tid_x * THREAD_TILE_DIM + k;
+    const uint32_t smem_col_start = tid_y * THREAD_TILE_DIM / 8;  // = tid_y
+    smem[smem_row][smem_col_start] = local_tile_transposed[k];
+  }
+
+  __syncthreads();
+
+// 4. Store from shared memory to dst in uint2 vectorized manner.
+#pragma unroll
+  for (uint32_t k = 0; k < THREAD_TILE_DIM; ++k) {
+    const uint32_t dst_global_row =
+        global_n_offset + tid_y * THREAD_TILE_DIM + k;
+    const uint32_t dst_global_col_start =
+        global_m_offset + tid_x * THREAD_TILE_DIM;
+
+    size_t offset = current_batch_offset +
+                    static_cast<size_t>(dst_global_row) * M +
+                    dst_global_col_start;
+    phi::float8_e4m3fn* dst_ptr = dst + offset;
+
+    fp8x8_t output_block;
+    const uint32_t smem_row = tid_y * THREAD_TILE_DIM + k;
+    const uint32_t smem_col = tid_x * THREAD_TILE_DIM / kVecSize;  // = tid_x
+    output_block = smem[smem_row][smem_col];
+    output_block.store(dst_ptr);
+  }
+}
+
+template <typename T>
+void dispatch_fp8_fast_transpose_kernel(const phi::GPUContext& d,
+                                        const T* input,
+                                        const uint32_t B,
+                                        const uint32_t M,
+                                        const uint32_t N,
+                                        T* output) {
+  dim3 grid, block;
+  block.x = BLOCK_DIM;  // 16 x 16 = 256 threads per block
+  block.y = BLOCK_DIM;
+
+  grid.z = B;
+  grid.y = M / BLOCK_TILE_SIZE;  // assumes M % BLOCK_TILE_SIZE == 0
+  grid.x = N / BLOCK_TILE_SIZE;  // assumes N % BLOCK_TILE_SIZE == 0
+
+  fp8_fast_transpose_kernel<<<grid, block, 0, d.stream()>>>(
+      input, output, B, M, N, static_cast<size_t>(M) * static_cast<size_t>(N));
+}
+
 // Here suppose convert all tensor to dim3, so just change dim1 and 2.
 template <typename T>
 void SendSwapDim1And2InTranspose(const phi::GPUContext& d,
                                  const T* input,
                                  const Dim3& input_dims,
                                  T* output) {
+  // FP8 fast path: both swapped dims must be multiples of 128.
+  if constexpr (std::is_same<T, phi::float8_e4m3fn>::value) {
+    if (input_dims[1] >= 128 && input_dims[2] >= 128 &&
+        input_dims[1] % 128 == 0 && input_dims[2] % 128 == 0) {
+      dispatch_fp8_fast_transpose_kernel(
+          d, input, input_dims[0], input_dims[1], input_dims[2], output);
+      return;
+    }
+  }
   // Suppose tile size > 16
   static const int kMinTileSize = 16;
   static const int kMinNarrowTileSize = 96;
diff --git a/test/legacy_test/test_transpose_op.py b/test/legacy_test/test_transpose_op.py
index f5ef7e3cf6f6e9..993e7fe59df9d4 100644
--- a/test/legacy_test/test_transpose_op.py
+++ b/test/legacy_test/test_transpose_op.py
@@ -224,6 +224,40 @@ def test_check_grad(self):
         )
 
 
+@unittest.skipIf(
+    not paddle.base.core.is_compiled_with_cuda()
+    or paddle.device.cuda.get_device_capability()[0] < 9.0,
+    "core is not compiled with CUDA or does not support native fp8",
+)
+class TestFP8FastTranspose(unittest.TestCase):
+    def setUp(self):
+        self.dtype = paddle.float8_e4m3fn
+        self.test_cases = [
+            {"shape": (7168, 16384), "perm": [1, 0], "name": "2D(7168,16384)"},
+            {
+                "shape": (8, 7168, 4096),
+                "perm": [0, 2, 1],
+                "name": "3D(8,7168,4096)",
+            },
+            {
+                "shape": (8, 2048, 7168),
+                "perm": [0, 2, 1],
+                "name": "3D(8,2048,7168)",
+            },
+        ]
+
+    def test_verify_transpose(self):
+        paddle.disable_static()
+        with paddle.no_grad():
+            for case in self.test_cases:
+                x = paddle.randn(case["shape"]).cast(self.dtype)
+                np_data = x.numpy()
+                gold = np.transpose(np_data, case["perm"])
+                out = paddle.transpose(x, case["perm"]).contiguous()
+                np.testing.assert_equal(out.numpy(), gold)
+        paddle.enable_static()
+
+
 class TestAutoTuneTransposeFP16Op(OpTest):
     def setUp(self):
         self.init_op_type()
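
Note: the fast path only triggers when both dimensions being swapped are multiples of 128 (per the check added to SendSwapDim1And2InTranspose); other FP8 shapes still fall back to the existing tiled kernels. A minimal sketch of exercising the new path from Python, assuming a CUDA build of Paddle on an SM90-class GPU (the shape mirrors one of the test cases above):

    import numpy as np
    import paddle

    # dims 2048 and 7168 are both multiples of 128, so this transpose should
    # take the FP8 fast path rather than the generic tiled kernel.
    x = paddle.randn([8, 2048, 7168]).cast(paddle.float8_e4m3fn)
    out = paddle.transpose(x, [0, 2, 1]).contiguous()
    ref = np.transpose(x.numpy(), [0, 2, 1])
    np.testing.assert_equal(out.numpy(), ref)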