From c742a0ece78a547d02a70b951d45b0c9e33ac447 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 7 May 2024 01:06:10 +0000 Subject: [PATCH 01/39] add files from fp6_llm --- torchao/csrc/cuda/fp6_llm/configs.h | 74 ++++++ torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 220 ++++++++++++++++++ torchao/csrc/cuda/fp6_llm/fp6_linear.cuh | 65 ++++++ torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh | 172 ++++++++++++++ .../csrc/cuda/fp6_llm/kernel_reduction.cuh | 47 ++++ torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh | 59 +++++ torchao/csrc/cuda/fp6_llm/ptx_mma.cuh | 113 +++++++++ torchao/csrc/cuda/fp6_llm/utils_core.cuh | 200 ++++++++++++++++ torchao/csrc/cuda/fp6_llm/utils_gmem.cuh | 75 ++++++ .../cuda/fp6_llm/utils_parallel_dequant.cuh | 111 +++++++++ torchao/csrc/cuda/fp6_llm/weight_dequant.h | 40 ++++ torchao/csrc/cuda/fp6_llm/weight_prepacking.h | 171 ++++++++++++++ torchao/csrc/cuda/fp6_llm/weight_quant.h | 105 +++++++++ torchao/csrc/fp6_llm.cpp | 10 + 14 files changed, 1462 insertions(+) create mode 100644 torchao/csrc/cuda/fp6_llm/configs.h create mode 100644 torchao/csrc/cuda/fp6_llm/fp6_linear.cu create mode 100644 torchao/csrc/cuda/fp6_llm/fp6_linear.cuh create mode 100644 torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh create mode 100644 torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh create mode 100644 torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh create mode 100644 torchao/csrc/cuda/fp6_llm/ptx_mma.cuh create mode 100644 torchao/csrc/cuda/fp6_llm/utils_core.cuh create mode 100644 torchao/csrc/cuda/fp6_llm/utils_gmem.cuh create mode 100644 torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh create mode 100644 torchao/csrc/cuda/fp6_llm/weight_dequant.h create mode 100644 torchao/csrc/cuda/fp6_llm/weight_prepacking.h create mode 100644 torchao/csrc/cuda/fp6_llm/weight_quant.h create mode 100644 torchao/csrc/fp6_llm.cpp diff --git a/torchao/csrc/cuda/fp6_llm/configs.h b/torchao/csrc/cuda/fp6_llm/configs.h new file mode 100644 index 0000000000..e6b217cdca --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/configs.h @@ -0,0 +1,74 @@ +#ifndef CONFIGS_H +#define CONFIGS_H + +//#define DEBUG_MODE +#define PIPELINE_LEVEL_GMEM 2 +#define PIPELINE_LEVEL_SMEM 2 // only support 2 + +/************************ Hardware Parameters ************************/ +#define WARP_SIZE 32 +#define REG_BIT_WIDTH 32 +// mma: M=16 K=16 N=8 +#define MMA_8 8 +#define MMA_16 16 +// for memory access +#define THREAD_OPT_ACCESS_BIT_WIDTH_128 128 // LDS.128, cp_async.128, ... 
+#define BIT_WIDTH_PER_HALF 16 // Half precision: FP16 + +/******************** Register Allocation For GEMM ********************/ +#define REG_PER_THREAD_C_TENSOR_16_16 8 // 8 for FP32 Accumulation +/********************** Memory Padding Parameters **********************/ +// Eliminating bank-conflict +#define PADDING_BYTES_16 16 // Padding 16 bytes each column +#define PADDING_SHARED_MEM_FOR_B_8 8 // Padding 8 half each column, during CopyFromGlobalToShared() for B +#define PADDING_SHARED_MEM_FOR_C_4 4 // Padding 4 float each column, during StoreToSharedMemoryFromRegister() for C +/************************* WARP Tiling part-1 *************************/ +#define WARP_ROW_MMA_TENSORS 4 +#define WARP_M (WARP_ROW_MMA_TENSORS * MMA_16) // 64 +#define WARP_K_MMA_TENSORS 4 +#define WARP_K (WARP_K_MMA_TENSORS * MMA_16) // 64 +template +struct TilingConfig { + // Depending on "n" dimension of the GEMM + static constexpr int BLOCK_ROW_WARPS = BLOCK_ROW_WARPS_; + static constexpr int BLOCK_COL_WARPS = BLOCK_COL_WARPS_; + static constexpr int WARP_COL_MMA_TENSORS = WARP_COL_MMA_TENSORS_; + /************************* WARP Tiling part-2 *************************/ + static constexpr int WARP_N = WARP_COL_MMA_TENSORS * MMA_8; + /*************************Thread Block Tiling *************************/ + static constexpr int TILE_M = WARP_M * BLOCK_ROW_WARPS; + static constexpr int TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS; + static constexpr int TILE_K = WARP_K; + /********************** #Thread per Thread Block **********************/ + static constexpr int BLOCK_WARPS = BLOCK_ROW_WARPS * BLOCK_COL_WARPS; + static constexpr int BLOCK_THREADS = BLOCK_WARPS * WARP_SIZE; + /******************************* Others *******************************/ + static constexpr int SMEM_SIZE_B_TILE = TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * PIPELINE_LEVEL_GMEM; // sizeof(half)=2, doubleBuffer=2 + static constexpr int SMEM_SIZE_C_TILE = TILE_N * (TILE_M + PADDING_BYTES_16) * 4; // sizeof(float)=4 +}; + +/************************ General Config for FP6-LLM **********************/ +#define WEIGHT_FRAG1_BIT_WIDTH 2 +#define WEIGHT_FRAG2_BIT_WIDTH 4 +#define WEIGHT_BIT_WIDTH (WEIGHT_FRAG1_BIT_WIDTH+WEIGHT_FRAG2_BIT_WIDTH) // 6 +//#define QUANT_GROUP_SIZE_DIVIDED_BY_64 4 // QuantGroupSize: 4*64 = 256 +/*************************** 64*64 Weghts of A WARP *************************/ +#define WEIGHT_PER_UNIT (WARP_M*WARP_K) // 64*64 +#define SMEM_SIZE_IN_BYTES_PER_WARP_A1 (WEIGHT_PER_UNIT*WEIGHT_FRAG1_BIT_WIDTH/8) // 1024 Bytes #doubleBuffer not takedn into consideration +#define SMEM_SIZE_IN_BYTES_PER_WARP_A2 (WEIGHT_PER_UNIT*WEIGHT_FRAG2_BIT_WIDTH/8) // 2048 Bytes #doubleBuffer not takedn into consideration +#define SMEM_SIZE_A1_TILE (SMEM_SIZE_IN_BYTES_PER_WARP_A1*4*PIPELINE_LEVEL_GMEM) // #WARP=4, #Trible-Buffer for 3-level pipeline for A = 12 KB; double buffer for 2-level pipeline A= 8 KB. +#define SMEM_SIZE_A2_TILE (SMEM_SIZE_IN_BYTES_PER_WARP_A2*4*PIPELINE_LEVEL_GMEM) // #WARP=4, #Trible-Buffer for 3-level pipeline for A = 24 KB; double buffer for 2-level pipeline A= 16 KB. 
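+// Worked example of the shared-memory budget implied by the formulas above (illustrative only:
+// the launcher picks the actual TilingConfig parameters per N). Assuming BLOCK_ROW_WARPS=4,
+// BLOCK_COL_WARPS=1, WARP_COL_MMA_TENSORS=8 -- values chosen here purely for this example:
+//   TILE_M = 64*4 = 256,  TILE_N = 8*8*1 = 64,  TILE_K = 64
+//   SMEM_SIZE_B_TILE  = 64 * (64+16)  * 2 * 2 = 20480 Bytes (20 KB)
+//   SMEM_SIZE_C_TILE  = 64 * (256+16) * 4     = 69632 Bytes (68 KB)
+//   SMEM_SIZE_A1_TILE + SMEM_SIZE_A2_TILE     = 8 KB + 16 KB = 24 KB
+// so the kernel would request max(20+8+16, 68) = 68 KB of dynamic shared memory via
+// cudaFuncSetAttribute() in Kernel_Ex().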
+/******************** Gloabl Memory Layout For QUANTIZED DATA ******************/ +#define NUM_INT4_PER_UNIT_2BIT_FRAG (WEIGHT_PER_UNIT*WEIGHT_FRAG1_BIT_WIDTH/128) // 64 +#define NUM_INT4_PER_UNIT_4BIT_FRAG (WEIGHT_PER_UNIT*WEIGHT_FRAG2_BIT_WIDTH/128) // 128 +/******************** Register Allocation For QUANTIZED DATA ******************/ +#define WEIGHT_PER_THREAD (WEIGHT_PER_UNIT/WARP_SIZE) // 128 +#define REG_PER_THREAD_2BIT_FRAG (WEIGHT_PER_THREAD/REG_BIT_WIDTH*2) // 8 +#define REG_PER_THREAD_4BIT_FRAG (WEIGHT_PER_THREAD/REG_BIT_WIDTH*4) // 16 +/******************** Register Allocation For QUANT Scales ******************/ +#define WARP_REG_QUANT_SCALE 4 // 8 rows per thread -> 8 FP16 scales -> 4 registers +#define WARP_REG_QUANT_SCALE_DISTRIBUTED 1 // T0-T3, T4-T7, ..., T28-T31 share the same scales, using shfl to get all the scales for each thread + + + +#endif // CONFIGS_H diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu new file mode 100644 index 0000000000..01bb114c16 --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -0,0 +1,220 @@ +#include "kernel_matmul.cuh" +#include "kernel_reduction.cuh" +#include "weight_prepacking.h" +#include "weight_dequant.h" +#include "weight_quant.h" + +#include +#include + +template +static void Kernel_Ex(cudaStream_t stream, + const uint4 *Weight, + const half *Scales, + const half *B, + OutputDataType *C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + int Split_K) +{ + #ifdef DEBUG_MODE + printf("\n"); + printf("Launcher.cu->Kernel_Ex():\n"); + printf("M: %d, N: %d, K: %d, SplitK: %d\n", M_Global, N_Global, K_Global, Split_K); + printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", TilingConfig::TILE_M, TilingConfig::TILE_K, TilingConfig::TILE_N); + #endif + static size_t SHMEM_SZ = max(TilingConfig::SMEM_SIZE_B_TILE+SMEM_SIZE_A1_TILE+SMEM_SIZE_A2_TILE, TilingConfig::SMEM_SIZE_C_TILE); + cudaFuncSetAttribute(QUANT_GEMM_Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ); + size_t dimN = (N_Global-1) / TilingConfig::TILE_N + 1; + size_t dimM = M_Global * Split_K / TilingConfig::TILE_M; + dim3 GridDim(dimN, dimM, 1); + dim3 BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1); + // + #ifdef DEBUG_MODE + printf("GridDim.x: %d, GridDim.y: %d, GridDim.z: %d, BlockDim.x: %d, BlockDim.y: %d, BlockDim.z: %d SHMEM_SZ: %d\n", + GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z, SHMEM_SZ); + printf("\n"); + #endif + QUANT_GEMM_Kernel<<>> + (Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); +} + +/* + * + */ +cudaError_t fp6_linear_kernel(cudaStream_t stream, + const uint4 *Weight, + const half *Scales, + const half *B, + half *C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + float *Reduction_Workspace, // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) + int Split_K) +{ + assert(M_Global % 256 == 0); + assert(K_Global % 64 == 0); + assert(N_Global>0); + + // Work around to support more N shapes: + size_t N_PowerOf2; + if(N_Global>0 && N_Global<=8) N_PowerOf2 = 8; + if(N_Global>8 && N_Global<=16) N_PowerOf2 = 16; + if(N_Global>16 && N_Global<=32) N_PowerOf2 = 32; + if(N_Global>32 && N_Global<=64) N_PowerOf2 = 64; + if(N_Global>64 && N_Global<=128) N_PowerOf2 = 128; + if(N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128; + + if (Split_K == 1) { + switch (N_PowerOf2) { + case 8: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 16: 
Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 32: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 64: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 128: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + default: if (N_PowerOf2 % 128 != 0) { + printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); + return cudaErrorUnknown; + } + Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + } + } + else { + switch (N_PowerOf2) { + case 8: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 16: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 32: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 64: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 128: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + default: if (N_PowerOf2 % 128 != 0) { + printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); + return cudaErrorUnknown; + } + Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + } + // Reduction for SplitK + dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1); + dim3 BlockDim(WARP_SIZE, 1, 1); + SplitK_Reduction<<>>(C, Reduction_Workspace, M_Global, N_Global, Split_K); + } + return cudaGetLastError(); +} + + + + + +#ifndef NO_PYTORCH +#include +#include + +namespace torchao { +/* +Computes FP6-FP16 GEMM (PyTorch interface). + +[Mathmatical Formula] +Standard definition of linear layer: Out = In * trans(W), where In, Out, and W are stored in row-major. +After Equivalent transformation : trans(Out) = W * trans(In). Note that we do not perform "transpose" during runtime, we instead interpret the In/Out as column-major matrices when calling our CUDA kernel. + +[Inputs] + _in_feats: tensor of shape [B, IC]; // half + _weights: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. + _scales: tensor of shape [OC]; // half + splitK: spliting the MatMul problem along K dimension for higher GPU utilization, default 1. +[Outputs] + _out_feats: tensor of shape [B, OC]; // half +*/ +torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, + torch::Tensor _weights, + torch::Tensor _scales, + int splitK=1) +{ + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + int num_out_channels = _weights.size(0); + assert( num_in_channels%64 == 0 ); + assert( (num_in_channels/16*3) == _weights.size(1) ); // Making sure the K dimension is matched. + // + int M = num_out_channels; + int K = num_in_channels; + int N = num_in_feats; + // Input Tensors + auto weight = reinterpret_cast(_weights.data_ptr()); // weights is [OC, IC] but in FP6. 
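+    // Each output channel thus packs IC FP6 weights into IC*6/32 int32 words (16 FP6 per
+    // 3 words, as noted above); the kernel reads them through this pointer as 128-bit uint4 loads.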
+ auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto scales = reinterpret_cast(_scales.data_ptr()); + // Output Tensors + auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); + at::Tensor _out_feats = torch::empty({num_in_feats, num_out_channels}, options); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + + options = torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device()); + at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options); + auto Reduction_Workspace = reinterpret_cast(_workspace.data_ptr()); // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) + + fp6_linear_kernel(0, // Using default stream here. + weight, + scales, + in_feats, + out_feats, + M, + N, + K, + Reduction_Workspace, + splitK); + + return _out_feats; +} + + +/* + * Weight prepacking (Pytorch interface). + * [Input & Output] + * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. + * [Output] + * packed_tensor: int tensor of shape [OC, IC // 16 * 3]; + */ +torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor) +{ + size_t OC = fp6_tensor.size(0); + size_t IC = fp6_tensor.size(1); + assert (IC%3==0); + IC = IC*16/3; + assert( (OC%256==0) && (IC%64==0) ); + auto packed_tensor = torch::empty_like(fp6_tensor); + auto packed_tensor_ptr = reinterpret_cast(packed_tensor.data_ptr()); + auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); + weight_matrix_prepacking(packed_tensor_ptr, fp6_tensor_ptr, OC, IC); + return packed_tensor; +} + +/* + * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs. + * A useful tool to construct input matrices for the FP16 GEMM baseline. + * [Input] + * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. + * fp16_scale: half tensor of shape [OC]; // for row-wise quantization. + * [Output] + * fp16_tensor: half tensor of shape [OC, IC]. + */ +torch::Tensor weight_matrix_dequant_cpu(torch::Tensor fp6_tensor, torch::Tensor fp16_scale) +{ + int OC = fp6_tensor.size(0); + assert(fp6_tensor.size(1) % 3 == 0); + int IC = fp6_tensor.size(1) / 3 * 16; + assert(fp16_scale.size(0)==OC); + // + auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); + auto fp16_scale_ptr = reinterpret_cast(fp16_scale.data_ptr()); + // + auto options = torch::TensorOptions().dtype(fp16_scale.dtype()).device(fp16_scale.device()); + at::Tensor fp16_tensor = torch::empty({OC, IC}, options); + auto fp16_tensor_ptr = reinterpret_cast(fp16_tensor.data_ptr()); + // + DeQuantMatrix_FP6_To_FP16(fp16_tensor_ptr, (unsigned char*)fp6_tensor_ptr, OC, IC, fp16_scale_ptr); + // + return fp16_tensor; +} +} +#endif diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cuh b/torchao/csrc/cuda/fp6_llm/fp6_linear.cuh new file mode 100644 index 0000000000..3bf3ee1dc4 --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cuh @@ -0,0 +1,65 @@ +#include +#include +#include + +/* +* Computes FP6-FP16 GEMM (C++ interface). +*/ +cudaError_t fp6_linear_kernel(cudaStream_t stream, + const uint4 *Weight, + const half *Scales, + const half *B, + half *C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + float *Reduction_Workspace, // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) + int Split_K); + +/* + * In-place weight prepacking (C++ interface). 
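+ * A minimal sketch of the intended call (buffer names are illustrative; both buffers hold
+ * M*K*6/8 bytes of packed FP6):
+ *   std::vector<int> packed(M * K * 6 / 8 / sizeof(int));
+ *   weight_matrix_prepacking(packed.data(), fp6_weights, M, K);
+ * Internally the 6-bit weights are split into 2-bit + 4-bit fragments, distributed across the
+ * 32 threads of a warp and bit-interleaved so the kernel can issue coalesced 128-bit loads
+ * (see weight_prepacking.h).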
+ */ +void weight_matrix_prepacking(int* packed_weights, int *FP6Weights, size_t M, size_t K); + +/* + * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs. + */ +void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale); + +#ifndef NO_PYTORCH +#include +#include + +namespace torchao { +/* +* Computes FP6-FP16 GEMM (PyTorch interface). +*/ +torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, + torch::Tensor _weights, + torch::Tensor _scales, + int splitK=1); + +/* + * Weight prepacking (Pytorch interface). + */ +torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor); + +/* + * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs. + * A useful tool to construct input matrices for the FP16 GEMM baseline. + * [Input] + * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. + * fp16_scale: half tensor of shape [OC]; // for row-wise quantization. + * [Output] + * fp16_tensor: half tensor of shape [OC, IC]. + */ +torch::Tensor weight_matrix_dequant_cpu(torch::Tensor fp6_tensor, torch::Tensor fp16_scale); + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::fp16act_fp6weight_linear", &fp6_linear_forward_cuda); + m.impl("torchao::fp6_weight_prepacking_cpu", &weight_matrix_prepacking_cpu); + m.impl("torchao::fp6_weight_dequant_cpu", &weight_matrix_dequant_cpu); +} + +} +#endif diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh new file mode 100644 index 0000000000..1a971f837f --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -0,0 +1,172 @@ +#include "configs.h" +#include "utils_gmem.cuh" +#include "utils_core.cuh" + +/* + * C = A*B + * A: row major with ahead-of-time layout transformation, FP6 + * B: col major, FP16 + * C: col major, FP16 + */ + template +__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, + const half *B, + OutputDataType* C, + const size_t M_Global, const size_t N_Global, const size_t K_Global, + int Split_K) +{ + #ifdef DEBUG_MODE + assert(K_Global%TilingConfig::TILE_K==0); + assert(M_Global%TilingConfig::TILE_M==0); + assert( gridDim.y == Split_K * (M_Global/TilingConfig::TILE_M)); + #endif + // 2+4 weight split + const uint4* Weight1 = Weight; + const uint4* Weight2 = Weight1 + M_Global*K_Global*2/128; + // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned + extern __shared__ __align__(128) half smem[]; + half (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast ( smem + (SMEM_SIZE_A1_TILE+SMEM_SIZE_A2_TILE)/2 ); // Dynamic shared memory for FP16 B tiles + __shared__ half QuantScales[64*TilingConfig::BLOCK_WARPS]; // static shared memory for quantization scales, 64 row per warp * 4 warps = 512 Bytes + // Thread Block Mapping, considering SplitK + const size_t BatchID = blockIdx.y / (M_Global/TilingConfig::TILE_M); + const size_t x = blockIdx.x; // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) + const size_t y = blockIdx.y % (M_Global/TilingConfig::TILE_M); // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) + const size_t Tile_Start_M = y * TilingConfig::TILE_M; + const size_t Tile_Start_N = x * TilingConfig::TILE_N; + const size_t NumColumnToCopy = (N_Global-Tile_Start_N) < TilingConfig::TILE_N ? 
(N_Global-Tile_Start_N) : TilingConfig::TILE_N; + const size_t NumBlock_K = K_Global/TilingConfig::TILE_K; + const size_t AverageNumBlock_K = NumBlock_K/Split_K; + const size_t ExtraNumBlock_K = NumBlock_K - AverageNumBlock_K * Split_K; + size_t NumIter = AverageNumBlock_K; + if(BatchID(smem); + uint32_t* AFrag_4BIT_SPTR = AFrag_2BIT_SPTR+SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM; // 8 buffers including double buffers, 12 for trible buffers + // StartSPTR for each WARP + AFrag_2BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4; + AFrag_4BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4; + // Pre-fetch of A tile + for(int i=0; i(AFrag_2BIT_SPTR+i*SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4, WARP_StartGPTR_A1); + CopyFromGlobalToShared_A(AFrag_4BIT_SPTR+i*SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4, WARP_StartGPTR_A2); + WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; + WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; + } + // Global Memory Address for Matrix A (QuantScale) ///////////////////////////////////////////////////////////////////// + const half* TB_StartGPTR_A_Scale = Scales + (y*TilingConfig::BLOCK_ROW_WARPS) * 64; + const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64; + CopyFromGlobalToShared_Scales(QuantScales+WARP_i*64, WARP_StartGPTR_A_Scales); + // Copying B tile from Global to Shared, considering SplitK ///////////////////////////////////////////////////////////// + const half *BTile_GPTR = B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; + for(int i=0; i (smem_array+i*TilingConfig::TILE_N, BTile_GPTR, K_Global, NumColumnToCopy); + BTile_GPTR += TilingConfig::TILE_K; + } + // Register Allocation for A,B, and C, Initilazed to Zeros ///////////////////////////////////////////////////////////////////// + constexpr int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + constexpr int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block +#ifdef PIPELINE_LEVEL_SMEM + uint32_t a [NumRegSets_a * PIPELINE_LEVEL_SMEM][4]; // double/Trible buffer is used // Registers to store decompressed FP6 + uint32_t b [NumRegSets_b * PIPELINE_LEVEL_SMEM][4]; // double/Triple buffer is used // Register to store FP16 B matrix (a slice) +#endif + float c[NumRegSets_a * NumRegSets_b][REG_PER_THREAD_C_TENSOR_16_16]; + for(int i=0; i(a, b, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); +#endif + // The outer loop. 
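+    // Schedule of each iteration below (descriptive sketch of the code that follows):
+    //   (1) issue cp.async copies of A (2-bit), A (4-bit) and B for a future k-tile into the
+    //       "write" shared-memory buffers, then cp_async_group_commit();
+    //   (2) run core_mma_slice() three times, each call doing tensor-core MMAs on the slice that
+    //       is already in registers while dequantizing/loading the next slice from shared memory;
+    //   (3) cp_async_wait_group() + __syncthreads() once the async copies are actually needed;
+    //   (4) a final core_mma_slice() computes the last slice and prefetches slice 0 of the
+    //       freshly arrived tile into registers for the next iteration.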
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + #pragma unroll(1) + for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) + { + // Trible-Buffer for A Tile + uint32_t* __restrict__ read_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ read_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 +#ifdef PIPELINE_LEVEL_SMEM + uint32_t* __restrict__ read2_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; + uint32_t* __restrict__ read2_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; +#endif + uint32_t* __restrict__ write_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ write_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + // Trible-Buffer for B Tile + half __restrict__ (*read_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; +#ifdef PIPELINE_LEVEL_SMEM + half __restrict__ (*read2_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; +#endif + half __restrict__ (*write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + // + bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter; + // Copying A tile from Global to Register, Bypassing L1, using double-buffer + CopyFromGlobalToShared_A(write_SPTR_Frag1, WARP_StartGPTR_A1, GlobalCopy); + CopyFromGlobalToShared_A(write_SPTR_Frag2, WARP_StartGPTR_A2, GlobalCopy); + // copying B tile from GlobalMemory to SharedMemory + CopyFromGlobalToShared (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy); + cp_async_group_commit(); + #ifdef PIPELINE_LEVEL_SMEM + core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each WARP; read_SPTR is shared among WARPs + core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 2); + core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 3); + // Barriers and Synchronizations + cp_async_wait_group(); + __syncthreads(); + core_mma_slice(c, a, b, read2_SPTR_Frag1, read2_SPTR_Frag2, read2_SPTR, Scales_RPTR, 0); + // Updating global PTRs + WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 + BTile_GPTR += TilingConfig::TILE_K; + #else + PipelinedCoreLoop(c, read_SPTR, read_SPTR_Frag1, read_SPTR_Frag2, Scales_RPTR); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each WARP; read_SPTR is shared among WARPs + // Updating global PTRs + WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; // 8KB/16=512 
(1)/16: int4*+1 = char*+16 + BTile_GPTR += TilingConfig::TILE_K; + // Barriers and Synchronizations + cp_async_wait_group(); + __syncthreads(); + #endif + } + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Store the C fragments to shared memory. + float (*smem_CFrag) [TilingConfig::TILE_M+PADDING_SHARED_MEM_FOR_C_4] = + reinterpret_cast (smem); + StoreToSharedMemoryFromRegister(smem_CFrag, c); + __syncthreads(); + // Now that shared memory contains all the D tiles, stream them to global memory. + OutputDataType* BlockGlobalPTR = C + BatchID*(M_Global*N_Global) + Tile_Start_M + Tile_Start_N*M_Global; + for(size_t i=warpId; i::value) BlockGlobalPTR[j+i*M_Global] = __float2half_rn(smem_CFrag[i][j]); + else BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j]; + } +} diff --git a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh new file mode 100644 index 0000000000..442de103b8 --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh @@ -0,0 +1,47 @@ +/*************************************************************************** + * Copyright 2023 The FLash-LLM Authors. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ***************************************************************************/ +// Used for the reduction of result matrix if Split-K is used +// Reduction_Workspace: (Split_K, M_Global, N_Global), column major +// C: (M_Global, N_Global), column major +// Each thread deals with 8 output elements, each elements is the sum of Split_K elements +// Read Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 float_per_thread (256bit) -> 256 float per warp +// Write Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 half_per_thread (128bit) -> 256 half per warp +// GridSize = (M_Global*N_Global) / 256 + +#include +#include +#include + +#define REDUCTION_ELEMENT_PER_THREADBLOCK 256 +#define HALF_PER_128BIT 8 + +__global__ void SplitK_Reduction(half* C, float* Reduction_Workspace, size_t M_Global, size_t N_Global, int Split_K) +{ + half* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + float* WARP_GPTR_R = Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + half* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; + float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT; + // Initializing Thread-Local Results + float Results[HALF_PER_128BIT]; + #pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) Results[i] = 0.0f; + // Reduction + for (int i = 0; i < Split_K; i++) { + #pragma unroll + for (int j = 0; j < HALF_PER_128BIT; j++) Results[j] += THREAD_GPTR_R[j]; + THREAD_GPTR_R += M_Global * N_Global; + } + // Writing to global memory + #pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); +} diff --git a/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh b/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh new file mode 100644 index 0000000000..8a6069ff1f --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh @@ -0,0 +1,59 @@ +/*************************************************************************** + * Copyright 2023 The FLash-LLM Authors. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ***************************************************************************/ +// Extended from CUTLASS's source code + +#ifndef PTX_CP_ASYNC_CUH +#define PTX_CP_ASYNC_CUH + +#include +#include +#include + +template +__device__ __forceinline__ void cp_async(half* smem_ptr, const half* global_ptr, bool pred_guard = true) +{ + static_assert(SizeInBytes == 16, "Size is not supported"); + unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr); + asm volatile("{ \n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred_guard), + "r"(smem_int_ptr), + "l"(global_ptr), + "n"(SizeInBytes)); +} + +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. +__device__ __forceinline__ void cp_async_group_commit() +{ + asm volatile("cp.async.commit_group;\n" ::); +} + +/// Blocks until all but previous cp.async.commit_group operations have committed. 
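+/// (At most N of the most recently committed groups may still be in flight afterwards.)
+/// Sketch of the double-buffered pattern these helpers are composed into by kernel_matmul.cuh
+/// (pointer names shortened for illustration):
+///   cp_async<16>(smem_write_ptr, gmem_ptr, pred);  // stage the next tile
+///   cp_async_group_commit();
+///   /* ... compute on the tile already resident in shared memory ... */
+///   cp_async_wait_group<0>();                      // or a larger N to keep copies in flight
+///   __syncthreads();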
+template +__device__ __forceinline__ void cp_async_wait_group() +{ + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +} + +/// Blocks until all previous cp.async.commit_group operations have committed. +// cp.async.wait_all is equivalent to : +// cp.async.commit_group; +// cp.async.wait_group 0; +__device__ __forceinline__ void cp_async_wait_all() +{ + asm volatile("cp.async.wait_all;\n" ::); +} + +#endif diff --git a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh new file mode 100644 index 0000000000..1920678244 --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh @@ -0,0 +1,113 @@ +/*************************************************************************** + * Copyright 2023 The FLash-LLM Authors. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ***************************************************************************/ +#ifndef PTX_MMA_CUH +#define PTX_MMA_CUH + +#include +#include +#include + +#include +#include "configs.h" + +#ifdef PIPELINE_LEVEL_SMEM +template +__device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[][4], + half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + int slice_id) { + #ifdef DEBUG_MODE + static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) ); + #endif + + const int warpId = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; + int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * WARP_j; // each warp may start from reading warp_start_col'th column of the B tile in shared memory + #ifdef DEBUG_MODE + assert( warp_start_col==0 ); + #endif + + int col = (lane_id%8) + (lane_id/16)*8; + int row = (lane_id%16) / 8 * 8; + uint32_t smem_local_ptr = static_cast(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][slice_id*MMA_16 + row])); + if(TilingConfig::WARP_COL_MMA_TENSORS==1) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(Reg[0][0]), "=r"(Reg[0][1]) + : "r"(smem_local_ptr)); + } + else { + #pragma unroll + for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS/2; i++) + { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) + : "r"(smem_local_ptr)); + smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); + } + } +} +#else +// Debug: Whether ldmatrix.trans is required??? 
+// B is in column-major +template +__device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[][4], + half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + int k_offset) { + #ifdef DEBUG_MODE + static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) ); + #endif + + const int warpId = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; + int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * WARP_j; // each warp may start from reading warp_start_col'th column of the B tile in shared memory + #ifdef DEBUG_MODE + assert( warp_start_col==0 ); + #endif + + int col = (lane_id%8) + (lane_id/16)*8; + int row = (lane_id%16) / 8 * 8; + uint32_t smem_local_ptr = static_cast(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][k_offset + row])); + if(TilingConfig::WARP_COL_MMA_TENSORS==1) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(Reg[0][0]), "=r"(Reg[0][1]) + : "r"(smem_local_ptr)); + } + else { + #pragma unroll + for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS/2; i++) + { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) + : "r"(smem_local_ptr)); + smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); + } + } +} +#endif + +__device__ __forceinline__ void +MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b) +{ + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{ %0, %1, %2, %3}," + "{ %4, %5, %6, %7 }," + "{ %8, %9 }," + "{ %10, %11, %12, %13 };" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +} + +#endif diff --git a/torchao/csrc/cuda/fp6_llm/utils_core.cuh b/torchao/csrc/cuda/fp6_llm/utils_core.cuh new file mode 100644 index 0000000000..e0d374f22f --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/utils_core.cuh @@ -0,0 +1,200 @@ +#ifndef UTILS_CORE_CUH +#define UTILS_CORE_CUH + +#include + +#include "configs.h" +#include "ptx_mma.cuh" +#include "utils_parallel_dequant.cuh" + + +#ifdef PIPELINE_LEVEL_SMEM +template +__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR, int slice_id) { + SPTR += slice_id * (NUM_INT_PER_THREAD*WARP_SIZE); + int lane_id = threadIdx.x % WARP_SIZE; + #pragma unroll + for(int i=0; i +__device__ __forceinline__ void initialize_mma_slice(uint32_t (*a)[4], + uint32_t (*b)[4], + uint32_t* __restrict__ A1_SPTR_read, + uint32_t* __restrict__ A2_SPTR_read, + half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + uint32_t* RPTR_Scales) +{ + // Writing registers + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; + uint32_t a_1[2]; // NO double buffer + uint32_t a_2[4]; // NO double buffer + CopyFromSharedToRegister_AFrag<2> (a_1, A1_SPTR_read, 0); + CopyFromSharedToRegister_AFrag<4> (a_2, A2_SPTR_read, 0); + Dequant_32FP6_4Way(a, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time + B_FromSharedToReg(b, B_SPTR_read, 0); // Loading B from shared to registers +} + +template +__device__ __forceinline__ void core_mma_slice(float c[][REG_PER_THREAD_C_TENSOR_16_16], + uint32_t (*a)[4], + 
uint32_t (*b)[4], + uint32_t* __restrict__ A1_SPTR_read, + uint32_t* __restrict__ A2_SPTR_read, + half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + uint32_t* RPTR_Scales, + int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching +{ + #ifdef DEBUG_MODE + assert((TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block + #endif + const int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block + uint32_t (*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = reinterpret_cast(c); // Reigsters for accumulated FP32 results + + // Setting RPTRs for double buffers + uint32_t (*a_read )[4] = a; + uint32_t (*a_write)[4] = a; + uint32_t (*b_read )[4] = b; + uint32_t (*b_write)[4] = b; + if(slice_id%2==1) { b_write += NumRegSets_b; a_write += NumRegSets_a;} + else { b_read += NumRegSets_b; a_read += NumRegSets_a;} + + // Reading registers and issuing core tensor core computations (a slice of A and B tile in shared memory) + #pragma unroll + for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { + if(TilingConfig::WARP_COL_MMA_TENSORS==1) { + MMA_FP16_M16N8K16( c_uint_ptr[i], a_read[i], b_read[0] ); + } + else { + #pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS/2; j++) { + MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i], b_read[j] ); + MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a_read[i], b_read[j] + 2 ); // c+4; b+2 + } + } + } + + // Writing registers + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; + uint32_t a_1[2]; // NO double buffer + uint32_t a_2[4]; // NO double buffer + CopyFromSharedToRegister_AFrag<2> (a_1, A1_SPTR_read, slice_id); + CopyFromSharedToRegister_AFrag<4> (a_2, A2_SPTR_read, slice_id); + Dequant_32FP6_4Way(a_write, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time + B_FromSharedToReg (b_write, B_SPTR_read, slice_id); // Loading B from shared to registers +} + +#else +// Old version with naive pipeline design +template +__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR) { + int lane_id = threadIdx.x % WARP_SIZE; + #pragma unroll + for(int i=0; i +__device__ __forceinline__ void PipelinedCoreLoop(float c[][REG_PER_THREAD_C_TENSOR_16_16], + half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + uint32_t* __restrict__ read_SPTR_Frag1, + uint32_t* __restrict__ read_SPTR_Frag2, + uint32_t* RPTR_Scales) +{ + #ifdef DEBUG_MODE + assert((TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block + #endif + const int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 
1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block + + // Reigsters to store FP32 results + uint32_t (*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = reinterpret_cast(c); + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; + uint32_t a_1[2*2]; // double buffer is used + uint32_t a_2[4*2]; // double buffer is used + // Registers to store decompressed FP6 + uint32_t a [NumRegSets_a * 1][4]; // No double buffer + // Register to store FP16 B matrix (a slice) + uint32_t b [NumRegSets_b * 2][4]; // double buffer is used + + // Overlapped Smem and TC pipeline: pre-loading from shared to registers + CopyFromSharedToRegister_AFrag<2> (a_1, read_SPTR_Frag1); + CopyFromSharedToRegister_AFrag<4> (a_2, read_SPTR_Frag2); + B_FromSharedToReg (b, read_SPTR, 0); + + #pragma unroll + for (int k = 0; k < WARP_K_MMA_TENSORS; k++) { + uint32_t (*b_read)[4] = b; + uint32_t (*b_write)[4] = b; + uint32_t *a_1_read = a_1; + uint32_t *a_1_write = a_1; + uint32_t *a_2_read = a_2; + uint32_t *a_2_write = a_2; + if(k%2==0) { + b_write += NumRegSets_b; + a_1_write += 2; + a_2_write += 4; + } + else { + b_read += NumRegSets_b; + a_1_read += 2; + a_2_read += 4; + } + // data loading + if (k + 1 < WARP_K_MMA_TENSORS) { + // updating SPTR for fragment1 and fragment2 + read_SPTR_Frag1 += 2*WARP_SIZE; + read_SPTR_Frag2 += 4*WARP_SIZE; + CopyFromSharedToRegister_AFrag<2>(a_1_write, read_SPTR_Frag1); + CopyFromSharedToRegister_AFrag<4>(a_2_write, read_SPTR_Frag2); + B_FromSharedToReg(b_write, read_SPTR, (k+1)*MMA_16); + } + // SIMT Dequant + Tensor Core computations + Dequant_32FP6_4Way(a, a_1_read, a_2_read, RPTR_Scales); // Dequantizing FP6 to FP16 at register level, dequantizing a slice each time + #pragma unroll + for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { + if(TilingConfig::WARP_COL_MMA_TENSORS==1) + MMA_FP16_M16N8K16( c_uint_ptr[i], a[i], b_read[0] ); + else { + #pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS/2; j++) { + MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a[i], b_read[j] ); + MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a[i], b_read[j] + 2 ); // c+4; b+2 + } + } + } + } +} +#endif // #ifdef PIPELINE_LEVEL_SMEM + +template +__device__ __forceinline__ void StoreToSharedMemoryFromRegister(float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4], + float c[][REG_PER_THREAD_C_TENSOR_16_16]) +{ + const int lane_id = threadIdx.x % WARP_SIZE; + const int warpId = threadIdx.x / WARP_SIZE; + int warp_row_offset = warpId * (MMA_16 * WARP_ROW_MMA_TENSORS); + #pragma unroll + for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { + #pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS; j++) { // Dealing with one 16*8 Tensor + int RegSetID = i + (j/2)*WARP_ROW_MMA_TENSORS; + int RegOffset = (j%2)*(REG_PER_THREAD_C_TENSOR_16_16/2); + int Tensor_row_offset = warp_row_offset + i * MMA_16; + int Tensor_col_offset = j * MMA_8; + #pragma unroll + for (int r = 0; r < REG_PER_THREAD_C_TENSOR_16_16/2; r++) { + int row_offset = lane_id / 4; + if (r >= 2) row_offset += 8; + int col_offset = (lane_id % 4) * 2; + if (r%2==1) col_offset += 1; + smem_CFrag[Tensor_col_offset + col_offset][Tensor_row_offset + row_offset] = c[RegSetID][r + RegOffset]; + } + } + } +} + +#endif diff --git a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh new file mode 100644 index 0000000000..86ba333b68 --- /dev/null +++ 
b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh @@ -0,0 +1,75 @@ +#ifndef UTILS_GMEM_CUH +#define UTILS_GMEM_CUH + +#include +#include "configs.h" +#include "ptx_cp.async.cuh" + +/* + * Copying A1/A2 from global memory to shared memory. + * Usually 1024 or 2048 Bytes + */ +template +__device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, + const uint4* GPTR, + bool pred_guard = true) { + #ifdef DEBUG_MODE + static_assert(SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE % 16 == 0); + #endif + int lane_id = threadIdx.x % WARP_SIZE; + half* SPTR_HALF = reinterpret_cast(SPTR); + const half* GPTR_HALF = reinterpret_cast(GPTR); + SPTR_HALF += lane_id*8; + GPTR_HALF += lane_id*8; + #pragma unroll + for(int i=0; i( SPTR_HALF, GPTR_HALF, pred_guard); + SPTR_HALF += 256; // Forward 512 Bytes + GPTR_HALF += 256; // Forward 512 Bytes + } + +} + +/* + * Copying 64 Quant Scales (FP16) from global memory to shared memory. + */ +__device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantScales, + const half* GPTR_A_Scales) { + int lane_id = threadIdx.x % WARP_SIZE; + int Offset_Shared = lane_id*2; + int Offset_Global = lane_id/4 + (lane_id%4)*16; + for(int i=0; i<2; i++) SPTR_QuantScales[Offset_Shared+i] = GPTR_A_Scales[Offset_Global+i*8]; +} + +/* + * (1) Copying X rows * 64 columns of FP16 values, originally in row major + * (2) Copying 64 rows * X columns of FP16 values, originally in column major + * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 Threads + */ +template +__device__ __forceinline__ void CopyFromGlobalToShared(half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + const half* GlobalPTR, + const int GlobalStride, + const int NumOfLinesLeft, // To support arbitrary N dimensions. + bool Pred = true) { + // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time + const int NumOfThreads = BLOCK_WARPS * WARP_SIZE; + const int NumOfGroups = NumOfThreads / 8; + const int MaxIteration = (MaxNumOfLinesToCopy-1) / NumOfGroups + 1; + // runtime variables + const int line_id = threadIdx.x / 8; + const int line_offset = (threadIdx.x%8) * 8; + // PTR for source global memory and target shared memory + GlobalPTR += line_id * GlobalStride + line_offset; + SharedPTR += line_id; + #pragma unroll + for (int i = 0; i < MaxIteration; i++) { + bool AsyncCopyPred = (line_id+i*NumOfGroups) < NumOfLinesLeft && Pred; + cp_async<16>( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred); + // + GlobalPTR += NumOfGroups * GlobalStride; + SharedPTR += NumOfGroups; + } +} + +#endif diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh new file mode 100644 index 0000000000..5a6977bc07 --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -0,0 +1,111 @@ +#ifndef UTILS_PARALLELDEQUANT_CUH +#define UTILS_PARALLELDEQUANT_CUH + +#include +#include +#include + +/* + * Input: R1 + * Outputs: R1, R2 + * Note: Simplified Exponent calculation is applied. + */ +__device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t *R1, u_int32_t *R2) { + *R2 = *R1 & 0x80808080; + *R1 = *R1 >> 2; + *R1 = *R1 & 0x1f1f1f1f; + *R2 = *R2 | *R1; + *R1 = *R2 & 0x9f009f00; + *R2 = *R2 & 0x009f009f; + *R2 = *R2 << 8; +} + +/* + * Input: R1 + * Outputs: R1, R2 + * Note: Simplified Exponent calculation is NOT applied. 
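+ * Background: FP6 here is 1 sign + 3 exponent (bias 3) + 2 mantissa bits, FP16 is
+ * 1 + 5 (bias 15) + 10.  The "simplified" cast above just drops the FP6 bits into the top of an
+ * FP16, so a normal value comes out 2^-12 times too small and is fixed up by the 4096.0f (= 2^12)
+ * factor in MultScale(); e.g. FP6 0b001100 (sign 0, exponent 011, mantissa 00, i.e. 1.0) becomes
+ * FP16 0x0C00 = 2^-12.  This naive variant instead re-biases the exponent field directly; the
+ * dequantization path below (Dequant_32FP6_4Way) uses the simplified variant.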
+ */ +__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t *R1, u_int32_t *R2) { + //*R2 = *R1 & 0x80808080; + *R2 = *R1 & 0xc0c0c0c0; + *R1 = *R1 >> 2; + //*R1 = *R1 & 0x1f1f1f1f; + *R1 = *R1 & 0x0f0f0f0f; + *R2 = *R2 | *R1; + // + //*R1 = *R2 & 0x9f009f00; + //*R2 = *R2 & 0x009f009f; + *R1 = *R2 & 0xcf00cf00; + if( !(*R1 & 0x40000000) && (*R1 & 0x0c000000) ) *R1 = *R1 | 0x30000000; + if( !(*R1 & 0x00004000) && (*R1 & 0x00000c00) ) *R1 = *R1 | 0x00003000; + *R2 = *R2 & 0x00cf00cf; + if( !(*R2 & 0x00400000) && (*R2 & 0x000c0000) ) *R2 = *R2 | 0x00300000; + if( !(*R2 & 0x00000040) && (*R2 & 0x0000000c) ) *R2 = *R2 | 0x00000030; + // + *R2 = *R2 << 8; + //*R1 = 0x3c003c00; + //*R2 = 0x3c003c00; +} + +__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) { + half* FP16_1 = reinterpret_cast(&PackedFP16Pair); + half* FP16_2 = FP16_1 + 1; + uint32_t output; + half* output_half_ptr = reinterpret_cast(&output); + output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(4096.0f)), Scale); + output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(4096.0f)), Scale); + return output; +} + +__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4], + u_int32_t __restrict__ *read_RPTR_Frag1, + u_int32_t __restrict__ *read_RPTR_Frag2, + u_int32_t *Scales) { + u_int32_t *OutputRegs = reinterpret_cast (Reg); + u_int32_t *Frag1_PTR = read_RPTR_Frag1; + u_int32_t *Frag2_PTR = read_RPTR_Frag2; + half *Scale_RPTR = reinterpret_cast(Scales); + u_int32_t Packed_FP6 = 0; + u_int32_t tmp = 0; + // Dequantizing 32 FP6, each Loop dequantizing 4 FP6 + #pragma unroll(8) + for(int i=0; i<8; i++) { + // Frag1 + Packed_FP6 = (*Frag1_PTR) & 0xc0c0c0c0; + if(i%4==3) Frag1_PTR++; + else (*Frag1_PTR) = (*Frag1_PTR) << 2; + // Frag2 + tmp = (*Frag2_PTR) & 0xf0f0f0f0; + tmp = tmp >> 2; + if(i%2==1) Frag2_PTR++; + else (*Frag2_PTR) = (*Frag2_PTR) << 4; + // Packed_FP6 + Packed_FP6 = Packed_FP6 | tmp; + // + FP6_FP16_Cast_4Way(&Packed_FP6, &tmp); + // + *OutputRegs = MultScale(Packed_FP6, Scale_RPTR[0] ); // Muliply FP16 scales + OutputRegs += 1; + *OutputRegs = MultScale(tmp, Scale_RPTR[1]); // Muliply FP16 scales + OutputRegs += 1; + // Updating offset for FP16 scales for every two iterations + if(i%2==1) Scale_RPTR += 2; + } + +} + +/* + * + */ +__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, half* WARP_SPTR_Scales) { + int lane_id = threadIdx.x % WARP_SIZE; + uint32_t* SPTR_uint = reinterpret_cast(WARP_SPTR_Scales); + uint32_t tmpReg = SPTR_uint[lane_id]; + #pragma unroll + for(int i=0; i<4; i++) { + // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); + Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); + } +} + +#endif diff --git a/torchao/csrc/cuda/fp6_llm/weight_dequant.h b/torchao/csrc/cuda/fp6_llm/weight_dequant.h new file mode 100644 index 0000000000..2451a55d78 --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/weight_dequant.h @@ -0,0 +1,40 @@ +#include +#include +#include +#include +#include +#include + +void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale) { + assert(M%64==0); // Currently, M must be a multiple of 64. + assert(K%64==0); // Currently, K must be a multiple of 64. 
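+    // Packed layout being decoded below: every 3 input bytes hold 4 six-bit values v0..v3,
+    //   byte0 = v0[5:0] v1[5:4],  byte1 = v1[3:0] v2[5:2],  byte2 = v2[1:0] v3[5:0]
+    // (the same layout produced by cast_fp16_fp6() in weight_quant.h).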
+ size_t TotalSizeInByte = M*K*6/8; + // + half* OutPTR = A_16bit_h; + for(size_t i=0; i>2)&0x1f); + unsigned char B2 = (A_6bit_h[i*3+0]<<6) | ((A_6bit_h[i*3+1]>>2)&0xfc); + B2 = (B2&0x80) | ((B2>>2)&0x1f); + unsigned char B3 = (A_6bit_h[i*3+1]<<4) | ((A_6bit_h[i*3+2]>>4)&0xfc); + B3 = (B3&0x80) | ((B3>>2)&0x1f); + unsigned char B4 = A_6bit_h[i*3+2]<<2; + B4 = (B4&0x80) | ((B4>>2)&0x1f); + half FP1, FP2, FP3, FP4; + unsigned char *PTR1, *PTR2, *PTR3, *PTR4; + PTR1 = reinterpret_cast(&FP1); + PTR2 = reinterpret_cast(&FP2); + PTR3 = reinterpret_cast(&FP3); + PTR4 = reinterpret_cast(&FP4); + PTR1[0] = 0; PTR1[1] = B1; // small endian for X86 CPU + PTR2[0] = 0; PTR2[1] = B2; + PTR3[0] = 0; PTR3[1] = B3; + PTR4[0] = 0; PTR4[1] = B4; + OutPTR[0] = __float2half_rn ( __half2float(FP1) * 4096.0f * __half2float(scale[(4*i)/K]) ); + OutPTR[1] = __float2half_rn ( __half2float(FP2) * 4096.0f * __half2float(scale[(4*i)/K]) ); + OutPTR[2] = __float2half_rn ( __half2float(FP3) * 4096.0f * __half2float(scale[(4*i)/K]) ); + OutPTR[3] = __float2half_rn ( __half2float(FP4) * 4096.0f * __half2float(scale[(4*i)/K]) ); + // + OutPTR +=4; + } +} diff --git a/torchao/csrc/cuda/fp6_llm/weight_prepacking.h b/torchao/csrc/cuda/fp6_llm/weight_prepacking.h new file mode 100644 index 0000000000..b33bdb18a9 --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/weight_prepacking.h @@ -0,0 +1,171 @@ +#include +#include +#include + +using namespace std; + +void Padding_8_FP6_To_8_Bytes(unsigned char Padded_FP6[], unsigned char* FP6_Array) // padding 0 to the lowerest bit location +{ + Padded_FP6[0] = FP6_Array[0] & 0xfc; + Padded_FP6[1] = (FP6_Array[0]<<6) | ((FP6_Array[1]>>2) & 0xfc); + Padded_FP6[2] = (FP6_Array[1]<<4) | ((FP6_Array[2]>>4) & 0xfc ); + Padded_FP6[3] = FP6_Array[2]<<2; + Padded_FP6[4] = FP6_Array[3] & 0xfc; + Padded_FP6[5] = (FP6_Array[3]<<6) | ((FP6_Array[4]>>2) & 0xfc); + Padded_FP6[6] = (FP6_Array[4]<<4) | ((FP6_Array[5]>>4) & 0xfc); + Padded_FP6[7] = FP6_Array[5]<<2; +} + +unsigned char Extract_2_Bits_From_4_PaddedFP6(unsigned char B1, unsigned char B2, unsigned char B3, unsigned char B4) +{ + unsigned char out; + out = (B1&0xc0) | ( (B2&0xc0) >> 2 ) | ( (B3&0xc0) >> 4 ) | ( (B4&0xc0) >> 6 ); + return out; +} + +unsigned char Extract_4_Bits_From_2_PaddedFP6(unsigned char B1, unsigned char B2) // The highest two bits are already extracted by Extract_2_Bits_From_4_PaddedFP6(); +{ + unsigned char out; + out = ( (B1<<2) & 0xf0 ) | ( (B2>>2) & 0x0f ); + return out; +} + +// dealing with 4 1*8 blocks of FP6 +void Assign_32_FP6_To_4_Thread(vector Seg_2bit[], vector Seg_4bit[], unsigned char* PTR_1, unsigned char* PTR_2, unsigned char* PTR_3, unsigned char* PTR_4) +{ + unsigned char Padded_8_FP8[4][8]; + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[0], PTR_1); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[1], PTR_2); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[2], PTR_3); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[3], PTR_4); + // + unsigned char Seg1_Byte1_T[4]; + unsigned char Seg1_Byte2_T[4]; + unsigned char Seg2_Byte1_T[4]; + unsigned char Seg2_Byte2_T[4]; + unsigned char Seg2_Byte3_T[4]; + unsigned char Seg2_Byte4_T[4]; + for(int t=0; t<4; t++) + { + Seg1_Byte1_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[0][0+t*2], Padded_8_FP8[0][1+t*2], Padded_8_FP8[1][0+t*2], Padded_8_FP8[1][1+t*2]); + Seg1_Byte2_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[2][0+t*2], Padded_8_FP8[2][1+t*2], Padded_8_FP8[3][0+t*2], Padded_8_FP8[3][1+t*2]); + Seg2_Byte1_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[0][0+t*2], 
Padded_8_FP8[0][1+t*2]); + Seg2_Byte2_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[1][0+t*2], Padded_8_FP8[1][1+t*2]); + Seg2_Byte3_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[2][0+t*2], Padded_8_FP8[2][1+t*2]); + Seg2_Byte4_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[3][0+t*2], Padded_8_FP8[3][1+t*2]); + } + // + for(int t=0; t<4; t++) + { + Seg_2bit[t].push_back(Seg1_Byte1_T[t]); + Seg_2bit[t].push_back(Seg1_Byte2_T[t]); + Seg_4bit[t].push_back(Seg2_Byte1_T[t]); + Seg_4bit[t].push_back(Seg2_Byte2_T[t]); + Seg_4bit[t].push_back(Seg2_Byte3_T[t]); + Seg_4bit[t].push_back(Seg2_Byte4_T[t]); + } + return; +} + +void BitInterleaving_2bit(unsigned char* PTR_4Bytes) +{ + unsigned int *PTR_UINT = reinterpret_cast(PTR_4Bytes); + unsigned int input = *PTR_UINT; + // + //int order_2bit[16] = {1,5,9,13,3,7,11,15,2,6,10,14,4,8,12,16}; // pre-defined order for bit-interleaving in FP6-LLM + int order_2bit[16] = {2,6,10,14,4,8,12,16,1,5,9,13,3,7,11,15}; // pre-defined order for bit-interleaving in FP6-LLM + unsigned int Frags_2bit[16]; // The highest 2 bits are used to store the extracted fragments. + for(int i=0; i<16; i++) + Frags_2bit[i] = ( input << 2*(order_2bit[i]-1) ) & 0xc0000000; + // + unsigned int output = 0x00000000; + for(int i=0; i<16; i++) + output |= ( Frags_2bit[i] >> (i*2) ); + // + *PTR_UINT = output; +} + +void BitInterleaving_4bit(unsigned char* PTR_4Bytes) +{ + unsigned int *PTR_UINT = reinterpret_cast(PTR_4Bytes); + unsigned int input = *PTR_UINT; + // + //int order_4bit[8] = {1,5,3,7,2,6,4,8}; // pre-defined order for bit-interleaving in FP6-LLM + int order_4bit[8] = {2,6,4,8,1,5,3,7}; // pre-defined order for bit-interleaving in FP6-LLM + unsigned int Frags_4bit[8]; // The highest4 bits are used to store the extracted fragments. + for(int i=0; i<8; i++) + Frags_4bit[i] = ( input << 4*(order_4bit[i]-1) ) & 0xf0000000; + // + unsigned int output = 0x00000000; + for(int i=0; i<8; i++) + output |= ( Frags_4bit[i] >> (i*4) ); + // + *PTR_UINT = output; +} + +/* + * Inputs: + * (1) unsigned char Weight_6bit [M*K*6/8] + * Outputs: + * (1) unsigned char Weight_2bit [M*K*2/8] + * (2) unsigned char Weight_4bit [M*K*4/8] + * + * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major. + * 8 FP6 = 6 Bytes + * 8 FP4 = 4 Bytes + * 8 FP2 = 2 Bytes + */ +void weight_matrix_prepacking(int* packed_weights, int *FP6Weights, size_t M, size_t K) +{ + assert(M % 64 == 0); + assert(K % 64 == 0); + // + unsigned char* Weight_6bit = reinterpret_cast(FP6Weights); + unsigned char* Weight_2bit = reinterpret_cast(packed_weights); + unsigned char* Weight_4bit = Weight_2bit + M*K*2/8; + // + vector A_Segment_2bit[32]; + vector A_Segment_4bit[32]; + // + size_t BytesPerRow = K*6/8; + // Pass-1: (1) 2+4 split; (2) assign weights to 32 threads. + for (size_t i = 0; i < M / 64; i++) // + { + for (size_t j = 0; j < K / 16; j++) + { + for(size_t k=0; k<64/16; k++) + { + size_t row = i*64 + k*16; + size_t col = j*16; + unsigned char* StartPTR_1 = Weight_6bit + row*BytesPerRow + col*6/8; + unsigned char* StartPTR_2 = StartPTR_1 + 8*BytesPerRow; + unsigned char* StartPTR_3 = StartPTR_1 + 8*6/8; + unsigned char* StartPTR_4 = StartPTR_2 + 8*6/8; + // Dealing with each 16*16 blocks then... 
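+                // Each l handles one row pair of the 16*16 block: rows l and l+8, columns 0-7 and
+                // 8-15, i.e. four 1x8 groups (32 FP6) whose fragments go to threads 4*l .. 4*l+3.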
+ for(int l=0; l<8; l++) Assign_32_FP6_To_4_Thread(&A_Segment_2bit[l*4], &A_Segment_4bit[l*4], StartPTR_1+l*BytesPerRow, StartPTR_2+l*BytesPerRow, StartPTR_3+l*BytesPerRow, StartPTR_4+l*BytesPerRow); + } + } + } + // Verifying the length of 2_bit segments and 4_bit segments + size_t BytesPerThread_2bit = M*K*2/8/32; + size_t BytesPerThread_4bit = M*K*4/8/32; + for(int i=0; i<32; i++) + { + assert(A_Segment_2bit[i].size()==BytesPerThread_2bit); + assert(A_Segment_4bit[i].size()==BytesPerThread_4bit); + } + // Pass-2: Optimizing coleasced global memory access + for(size_t i=0; i + +/* + * Function to pack 4 fake quantized FP16 value into continuously stored 4 FP6 values. + */ +void cast_fp16_fp6(uint16_t* FP16x4, uint8_t* FP6x4) +{ + // Constants for FP6 + constexpr int exponent_nbits_fp6 = 3; + constexpr int mantissa_nbits_fp6 = 2; + constexpr int exp_bias_fp6 = (1 << (exponent_nbits_fp6 - 1)) - 1; + // Constants for FP16 + constexpr int exponent_nbits_fp16 = 5; + constexpr int mantissa_nbits_fp16 = 10; + constexpr int exp_bias_fp16 = (1 << (exponent_nbits_fp16 - 1)) - 1; + + int fp6_temp[4]; + + float absmin_nonzero_fp6 = 0.0625; + // Note that we regard the exponent of '111' as a regular value rather than NaN or inf. This is + // the same with that in qtorch. + float absmax_fp6 = 28; + + for (int i = 0; i < 4; ++i) { + uint16_t source = FP16x4[i]; + float fp6_value_abs = std::abs(__half2float(*((half*)(&source)))); + if ((fp6_value_abs != 0 && fp6_value_abs < absmin_nonzero_fp6) || + fp6_value_abs > absmax_fp6) { + // TODO(zhen): a better way may be rounding it to the nearest FP6 value. + throw std::invalid_argument("Input value out of range for FP6."); + } + + // It is not safe to do shift operation on uint16_t. So we promote it to int. + int source_promote = int(source); + + int sign_bit = (source_promote >> 15); + // Extracting exponent represented in FP16. The sign mask 0x7FFF is '0111 1111 1111 1111' + int exp_bit = (source_promote & 0x7FFF) >> mantissa_nbits_fp16; + // Extracting mantissa represented in FP16 + int mant_bit = source_promote & ((1 << mantissa_nbits_fp16) - 1); + + int new_exp_bit; + int new_mant_bit; + + if (exp_bit == 0) { + // Subnormal FP16 number. Too small for FP6. + new_exp_bit = 0; + new_mant_bit = 0; + } else { + new_mant_bit = mant_bit >> (mantissa_nbits_fp16 - mantissa_nbits_fp6); + new_exp_bit = exp_bit - exp_bias_fp16 + exp_bias_fp6; + + // Deal with subnormal FP6 values. + int target_exp_val = exp_bit - exp_bias_fp16; + int min_fp6_exp_val = -exp_bias_fp6 + 1; + bool subnormal_fp6 = target_exp_val < min_fp6_exp_val; + if (subnormal_fp6) { + // TODO(zhen): add the rounding logic. + new_exp_bit = 0; + // The implicit 1 in the mantissa of FP16 is not present in subnormal FP6. Thus we + // need to add it + new_mant_bit = (new_mant_bit | (1 << mantissa_nbits_fp6)) >> + (min_fp6_exp_val - target_exp_val); + } + } + + fp6_temp[i] = (sign_bit << (exponent_nbits_fp6 + mantissa_nbits_fp6)) | + (new_exp_bit << mantissa_nbits_fp6) | new_mant_bit; + } + // Pack the values + FP6x4[0] = fp6_temp[0] << 2 | (fp6_temp[1] >> 4); + FP6x4[1] = (fp6_temp[1] & 0x0F) << 4 | (fp6_temp[2] >> 2); + FP6x4[2] = (fp6_temp[2] & 0x03) << 6 | fp6_temp[3]; +} + +/* + * Function to prepack FP16 weights into continuous FP6 values. 
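+ * (cast_fp16_fp6 above keeps 1 sign, 3 exponent and 2 mantissa bits per value, truncating the
+ *  remaining mantissa bits; every 4 FP6 values are packed into 3 bytes, so a row of K weights
+ *  becomes K * 6 / 8 bytes.)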
+ * + * Parameters: + * weight_16bit: input weight in FP16, size M*K + * weight_6bit: output weight in packed FP6, continuously stored, size M*K*6/8 + * M, K: the shape of the weight + */ +void weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit, + uint8_t* weight_6bit_packed, + size_t M, + size_t K) +{ + // Every four 16-bit elements are packed into three 6-bit values (4*6bit == 3*8bit). + if (K * 6 % 8 != 0) { throw std::invalid_argument("(K * 6 % 8) should be 0"); } + size_t K_fp6_packed = K * 6 / 8; + // #pragma omp parallel for + for (auto m = 0; m < M; m++) { + uint8_t* ptr_6bit = weight_6bit_packed + m * K_fp6_packed; + uint16_t* ptr_16bit = weight_16bit + m * K; + for (auto k = 0; k < K; k += 4) { + cast_fp16_fp6(ptr_16bit, ptr_6bit); + ptr_16bit += 4; + ptr_6bit += 3; + } + } +} diff --git a/torchao/csrc/fp6_llm.cpp b/torchao/csrc/fp6_llm.cpp new file mode 100644 index 0000000000..1d2f31e2fe --- /dev/null +++ b/torchao/csrc/fp6_llm.cpp @@ -0,0 +1,10 @@ +#include +#include +#include + +TORCH_LIBRARY_FRAGMENT(torchao, m) { + m.impl_abstract_pystub("torchao.ops"); + m.def("fp16act_fp6weight_linear(Tensor _weights, Tensor _scales, int splitK = 1) -> Tensor"); + m.def("fp6_weight_prepacking_cpu(Tensor fp6_tensor) -> Tensor"); + m.def("fp6_weight_dequant_cpu(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); +} From 4eb8be67b4919829f27f4f6a1818453b7d27ab8e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 7 May 2024 09:52:05 +0000 Subject: [PATCH 02/39] try to port weight packing first --- setup.py | 1 + torchao/csrc/fp6_llm.cpp | 10 -- torchao/csrc/fp6_test/fp6_llm.cpp | 8 + torchao/csrc/fp6_test/weight_packing.cpp | 204 +++++++++++++++++++++++ 4 files changed, 213 insertions(+), 10 deletions(-) delete mode 100644 torchao/csrc/fp6_llm.cpp create mode 100644 torchao/csrc/fp6_test/fp6_llm.cpp create mode 100644 torchao/csrc/fp6_test/weight_packing.cpp diff --git a/setup.py b/setup.py index 3972cb2c76..771ff480a1 100644 --- a/setup.py +++ b/setup.py @@ -64,6 +64,7 @@ def get_extensions(): this_dir = os.path.dirname(os.path.curdir) extensions_dir = os.path.join(this_dir, "torchao", "csrc") sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp"))) + sources += list(glob.glob(os.path.join(extensions_dir, "fp6_test/*.cpp"))) extensions_cuda_dir = os.path.join(extensions_dir, "cuda") cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu"))) diff --git a/torchao/csrc/fp6_llm.cpp b/torchao/csrc/fp6_llm.cpp deleted file mode 100644 index 1d2f31e2fe..0000000000 --- a/torchao/csrc/fp6_llm.cpp +++ /dev/null @@ -1,10 +0,0 @@ -#include -#include -#include - -TORCH_LIBRARY_FRAGMENT(torchao, m) { - m.impl_abstract_pystub("torchao.ops"); - m.def("fp16act_fp6weight_linear(Tensor _weights, Tensor _scales, int splitK = 1) -> Tensor"); - m.def("fp6_weight_prepacking_cpu(Tensor fp6_tensor) -> Tensor"); - m.def("fp6_weight_dequant_cpu(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); -} diff --git a/torchao/csrc/fp6_test/fp6_llm.cpp b/torchao/csrc/fp6_test/fp6_llm.cpp new file mode 100644 index 0000000000..c204d105bd --- /dev/null +++ b/torchao/csrc/fp6_test/fp6_llm.cpp @@ -0,0 +1,8 @@ +#include +#include +#include + +TORCH_LIBRARY_FRAGMENT(torchao, m) { + m.impl_abstract_pystub("torchao.ops"); + m.def("weight_matrix_prepacking_cpu(Tensor fp6_tensor) -> Tensor"); +} diff --git a/torchao/csrc/fp6_test/weight_packing.cpp b/torchao/csrc/fp6_test/weight_packing.cpp new file mode 100644 index 0000000000..5cf926096a --- /dev/null +++ b/torchao/csrc/fp6_test/weight_packing.cpp @@ -0,0 
+1,204 @@ +#include +#include +#include + +using namespace std; + +void Padding_8_FP6_To_8_Bytes(unsigned char Padded_FP6[], unsigned char* FP6_Array) // padding 0 to the lowerest bit location +{ + Padded_FP6[0] = FP6_Array[0] & 0xfc; + Padded_FP6[1] = (FP6_Array[0]<<6) | ((FP6_Array[1]>>2) & 0xfc); + Padded_FP6[2] = (FP6_Array[1]<<4) | ((FP6_Array[2]>>4) & 0xfc ); + Padded_FP6[3] = FP6_Array[2]<<2; + Padded_FP6[4] = FP6_Array[3] & 0xfc; + Padded_FP6[5] = (FP6_Array[3]<<6) | ((FP6_Array[4]>>2) & 0xfc); + Padded_FP6[6] = (FP6_Array[4]<<4) | ((FP6_Array[5]>>4) & 0xfc); + Padded_FP6[7] = FP6_Array[5]<<2; +} + +unsigned char Extract_2_Bits_From_4_PaddedFP6(unsigned char B1, unsigned char B2, unsigned char B3, unsigned char B4) +{ + unsigned char out; + out = (B1&0xc0) | ( (B2&0xc0) >> 2 ) | ( (B3&0xc0) >> 4 ) | ( (B4&0xc0) >> 6 ); + return out; +} + +unsigned char Extract_4_Bits_From_2_PaddedFP6(unsigned char B1, unsigned char B2) // The highest two bits are already extracted by Extract_2_Bits_From_4_PaddedFP6(); +{ + unsigned char out; + out = ( (B1<<2) & 0xf0 ) | ( (B2>>2) & 0x0f ); + return out; +} + +// dealing with 4 1*8 blocks of FP6 +void Assign_32_FP6_To_4_Thread(vector Seg_2bit[], vector Seg_4bit[], unsigned char* PTR_1, unsigned char* PTR_2, unsigned char* PTR_3, unsigned char* PTR_4) +{ + unsigned char Padded_8_FP8[4][8]; + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[0], PTR_1); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[1], PTR_2); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[2], PTR_3); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[3], PTR_4); + // + unsigned char Seg1_Byte1_T[4]; + unsigned char Seg1_Byte2_T[4]; + unsigned char Seg2_Byte1_T[4]; + unsigned char Seg2_Byte2_T[4]; + unsigned char Seg2_Byte3_T[4]; + unsigned char Seg2_Byte4_T[4]; + for(int t=0; t<4; t++) + { + Seg1_Byte1_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[0][0+t*2], Padded_8_FP8[0][1+t*2], Padded_8_FP8[1][0+t*2], Padded_8_FP8[1][1+t*2]); + Seg1_Byte2_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[2][0+t*2], Padded_8_FP8[2][1+t*2], Padded_8_FP8[3][0+t*2], Padded_8_FP8[3][1+t*2]); + Seg2_Byte1_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[0][0+t*2], Padded_8_FP8[0][1+t*2]); + Seg2_Byte2_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[1][0+t*2], Padded_8_FP8[1][1+t*2]); + Seg2_Byte3_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[2][0+t*2], Padded_8_FP8[2][1+t*2]); + Seg2_Byte4_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[3][0+t*2], Padded_8_FP8[3][1+t*2]); + } + // + for(int t=0; t<4; t++) + { + Seg_2bit[t].push_back(Seg1_Byte1_T[t]); + Seg_2bit[t].push_back(Seg1_Byte2_T[t]); + Seg_4bit[t].push_back(Seg2_Byte1_T[t]); + Seg_4bit[t].push_back(Seg2_Byte2_T[t]); + Seg_4bit[t].push_back(Seg2_Byte3_T[t]); + Seg_4bit[t].push_back(Seg2_Byte4_T[t]); + } + return; +} + +void BitInterleaving_2bit(unsigned char* PTR_4Bytes) +{ + unsigned int *PTR_UINT = reinterpret_cast(PTR_4Bytes); + unsigned int input = *PTR_UINT; + // + //int order_2bit[16] = {1,5,9,13,3,7,11,15,2,6,10,14,4,8,12,16}; // pre-defined order for bit-interleaving in FP6-LLM + int order_2bit[16] = {2,6,10,14,4,8,12,16,1,5,9,13,3,7,11,15}; // pre-defined order for bit-interleaving in FP6-LLM + unsigned int Frags_2bit[16]; // The highest 2 bits are used to store the extracted fragments. 
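+    // The loop below moves input 2-bit field order_2bit[i] (fields numbered 1..16 from the
+    // most-significant end) into output field i+1: e.g. order_2bit[0] = 2, so the 2nd field of
+    // `input` becomes the 1st field of `output`.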
+ for(int i=0; i<16; i++) + Frags_2bit[i] = ( input << 2*(order_2bit[i]-1) ) & 0xc0000000; + // + unsigned int output = 0x00000000; + for(int i=0; i<16; i++) + output |= ( Frags_2bit[i] >> (i*2) ); + // + *PTR_UINT = output; +} + +void BitInterleaving_4bit(unsigned char* PTR_4Bytes) +{ + unsigned int *PTR_UINT = reinterpret_cast(PTR_4Bytes); + unsigned int input = *PTR_UINT; + // + //int order_4bit[8] = {1,5,3,7,2,6,4,8}; // pre-defined order for bit-interleaving in FP6-LLM + int order_4bit[8] = {2,6,4,8,1,5,3,7}; // pre-defined order for bit-interleaving in FP6-LLM + unsigned int Frags_4bit[8]; // The highest4 bits are used to store the extracted fragments. + for(int i=0; i<8; i++) + Frags_4bit[i] = ( input << 4*(order_4bit[i]-1) ) & 0xf0000000; + // + unsigned int output = 0x00000000; + for(int i=0; i<8; i++) + output |= ( Frags_4bit[i] >> (i*4) ); + // + *PTR_UINT = output; +} + +/* + * Inputs: + * (1) unsigned char Weight_6bit [M*K*6/8] + * Outputs: + * (1) unsigned char Weight_2bit [M*K*2/8] + * (2) unsigned char Weight_4bit [M*K*4/8] + * + * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major. + * 8 FP6 = 6 Bytes + * 8 FP4 = 4 Bytes + * 8 FP2 = 2 Bytes + */ +void weight_matrix_prepacking(int* packed_weights, int *FP6Weights, size_t M, size_t K) +{ + assert(M % 64 == 0); + assert(K % 64 == 0); + // + unsigned char* Weight_6bit = reinterpret_cast(FP6Weights); + unsigned char* Weight_2bit = reinterpret_cast(packed_weights); + unsigned char* Weight_4bit = Weight_2bit + M*K*2/8; + // + vector A_Segment_2bit[32]; + vector A_Segment_4bit[32]; + // + size_t BytesPerRow = K*6/8; + // Pass-1: (1) 2+4 split; (2) assign weights to 32 threads. + for (size_t i = 0; i < M / 64; i++) // + { + for (size_t j = 0; j < K / 16; j++) + { + for(size_t k=0; k<64/16; k++) + { + size_t row = i*64 + k*16; + size_t col = j*16; + unsigned char* StartPTR_1 = Weight_6bit + row*BytesPerRow + col*6/8; + unsigned char* StartPTR_2 = StartPTR_1 + 8*BytesPerRow; + unsigned char* StartPTR_3 = StartPTR_1 + 8*6/8; + unsigned char* StartPTR_4 = StartPTR_2 + 8*6/8; + // Dealing with each 16*16 blocks then... + for(int l=0; l<8; l++) Assign_32_FP6_To_4_Thread(&A_Segment_2bit[l*4], &A_Segment_4bit[l*4], StartPTR_1+l*BytesPerRow, StartPTR_2+l*BytesPerRow, StartPTR_3+l*BytesPerRow, StartPTR_4+l*BytesPerRow); + } + } + } + // Verifying the length of 2_bit segments and 4_bit segments + size_t BytesPerThread_2bit = M*K*2/8/32; + size_t BytesPerThread_4bit = M*K*4/8/32; + for(int i=0; i<32; i++) + { + assert(A_Segment_2bit[i].size()==BytesPerThread_2bit); + assert(A_Segment_4bit[i].size()==BytesPerThread_4bit); + } + // Pass-2: Optimizing coleasced global memory access + for(size_t i=0; i +#include + +namespace torchao { + +/* + * Weight prepacking (Pytorch interface). + * [Input & Output] + * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. 
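+ *                 (16 FP6 values = 16 * 6 = 96 bits = 3 * 32 bits, so each row of IC FP6
+ *                  weights takes exactly IC // 16 * 3 INT32 words.)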
+ * [Output] + * packed_tensor: int tensor of shape [OC, IC // 16 * 3]; + */ +at::Tensor weight_matrix_prepacking_cpu(at::Tensor fp6_tensor) +{ + size_t OC = fp6_tensor.size(0); + size_t IC = fp6_tensor.size(1); + TORCH_CHECK(IC % 3 == 0); + IC = IC * 16 / 3; + TORCH_CHECK((OC % 256 == 0) && (IC % 64 == 0)); + auto packed_tensor = at::empty_like(fp6_tensor); + auto packed_tensor_ptr = reinterpret_cast(packed_tensor.data_ptr()); + auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); + weight_matrix_prepacking(packed_tensor_ptr, fp6_tensor_ptr, OC, IC); + return packed_tensor; +} + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::weight_matrix_prepacking_cpu", &weight_matrix_prepacking_cpu); +} + +} From 8608664bb864cfb5b2b7915309ff3d7a50bee5b5 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 7 May 2024 09:53:01 +0000 Subject: [PATCH 03/39] rename --- .../csrc/fp6_test/{weight_packing.cpp => weight_prepacking.cpp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename torchao/csrc/fp6_test/{weight_packing.cpp => weight_prepacking.cpp} (100%) diff --git a/torchao/csrc/fp6_test/weight_packing.cpp b/torchao/csrc/fp6_test/weight_prepacking.cpp similarity index 100% rename from torchao/csrc/fp6_test/weight_packing.cpp rename to torchao/csrc/fp6_test/weight_prepacking.cpp From b7c7b2876e789b8117812a1954a29a686b1c6895 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 7 May 2024 20:39:07 +0800 Subject: [PATCH 04/39] rename fp6 weight packing --- test/test_ops.py | 12 ++++++++++++ torchao/csrc/fp6_test/fp6_llm.cpp | 2 +- torchao/csrc/fp6_test/weight_prepacking.cpp | 2 +- torchao/ops.py | 13 +++++++++++++ 4 files changed, 27 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index a569f24799..6f46a46bef 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -42,6 +42,18 @@ def test_nms(self): test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] opcheck(torch.ops.torchao.nms, (boxes, scores, iou), test_utils=test_utils) + def test_prepack_fp6_weight(self): + OC = 256 + IC = 256 + fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int) + + # smoke test + torchao.ops.prepack_fp6_weight(fp6_weight) + + # comprehensive testing + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] + opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils) + if __name__ == "__main__": unittest.main() diff --git a/torchao/csrc/fp6_test/fp6_llm.cpp b/torchao/csrc/fp6_test/fp6_llm.cpp index c204d105bd..6e848c785b 100644 --- a/torchao/csrc/fp6_test/fp6_llm.cpp +++ b/torchao/csrc/fp6_test/fp6_llm.cpp @@ -4,5 +4,5 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.impl_abstract_pystub("torchao.ops"); - m.def("weight_matrix_prepacking_cpu(Tensor fp6_tensor) -> Tensor"); + m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); } diff --git a/torchao/csrc/fp6_test/weight_prepacking.cpp b/torchao/csrc/fp6_test/weight_prepacking.cpp index 5cf926096a..2b03eb32ab 100644 --- a/torchao/csrc/fp6_test/weight_prepacking.cpp +++ b/torchao/csrc/fp6_test/weight_prepacking.cpp @@ -198,7 +198,7 @@ at::Tensor weight_matrix_prepacking_cpu(at::Tensor fp6_tensor) } TORCH_LIBRARY_IMPL(torchao, CPU, m) { - m.impl("torchao::weight_matrix_prepacking_cpu", &weight_matrix_prepacking_cpu); + m.impl("torchao::prepack_fp6_weight", &weight_matrix_prepacking_cpu); } } diff --git a/torchao/ops.py b/torchao/ops.py index 0931d32026..08253c5f59 100644 --- 
a/torchao/ops.py +++ b/torchao/ops.py @@ -21,3 +21,16 @@ def _(dets, scores, iou_threshold): ctx = torch._custom_ops.get_ctx() num_to_keep = ctx.create_unbacked_symint() return dets.new_empty(num_to_keep, dtype=torch.long) + + +def prepack_fp6_weight(fp6_weight: Tensor) -> Tensor: + return torch.ops.torchao.prepack_fp6_weight.default(fp6_weight) + + +@torch.library.impl_abstract("torchao::prepack_fp6_weight") +def _(fp6_weight): + torch._check(fp6_weight.dim() == 2, lambda: f"weight should be a 2d tensor, got {dets.dim()}D") + # ctx = torch._custom_ops.get_ctx() + # num_to_keep = ctx.create_unbacked_symint() + # return fp6_weight.new_empty(num_to_keep, dtype=torch.long) + return torch.empty_like(fp6_weight) From 3c9aac78c559819eee4584e702b1d34c5390d74d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 7 May 2024 21:19:26 +0800 Subject: [PATCH 05/39] add fp16act_fp6weight_linear --- setup.py | 1 + torchao/csrc/cuda/fp6_llm/fp6_linear.cuh | 65 ------------------- .../csrc/cuda/{fp6_llm => fp6_test}/configs.h | 0 .../cuda/{fp6_llm => fp6_test}/fp6_linear.cu | 62 ++---------------- .../{fp6_llm => fp6_test}/kernel_matmul.cuh | 0 .../kernel_reduction.cuh | 0 .../{fp6_llm => fp6_test}/ptx_cp.async.cuh | 0 .../cuda/{fp6_llm => fp6_test}/ptx_mma.cuh | 0 .../cuda/{fp6_llm => fp6_test}/utils_core.cuh | 0 .../cuda/{fp6_llm => fp6_test}/utils_gmem.cuh | 0 .../utils_parallel_dequant.cuh | 0 torchao/csrc/fp6_test/fp6_llm.cpp | 1 + torchao/ops.py | 4 ++ 13 files changed, 11 insertions(+), 122 deletions(-) delete mode 100644 torchao/csrc/cuda/fp6_llm/fp6_linear.cuh rename torchao/csrc/cuda/{fp6_llm => fp6_test}/configs.h (100%) rename torchao/csrc/cuda/{fp6_llm => fp6_test}/fp6_linear.cu (79%) rename torchao/csrc/cuda/{fp6_llm => fp6_test}/kernel_matmul.cuh (100%) rename torchao/csrc/cuda/{fp6_llm => fp6_test}/kernel_reduction.cuh (100%) rename torchao/csrc/cuda/{fp6_llm => fp6_test}/ptx_cp.async.cuh (100%) rename torchao/csrc/cuda/{fp6_llm => fp6_test}/ptx_mma.cuh (100%) rename torchao/csrc/cuda/{fp6_llm => fp6_test}/utils_core.cuh (100%) rename torchao/csrc/cuda/{fp6_llm => fp6_test}/utils_gmem.cuh (100%) rename torchao/csrc/cuda/{fp6_llm => fp6_test}/utils_parallel_dequant.cuh (100%) diff --git a/setup.py b/setup.py index 771ff480a1..cf0badfb67 100644 --- a/setup.py +++ b/setup.py @@ -68,6 +68,7 @@ def get_extensions(): extensions_cuda_dir = os.path.join(extensions_dir, "cuda") cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu"))) + cuda_sources += list(glob.glob(os.path.join(extensions_cuda_dir, "fp6_test/*.cu"))) if use_cuda: sources += cuda_sources diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cuh b/torchao/csrc/cuda/fp6_llm/fp6_linear.cuh deleted file mode 100644 index 3bf3ee1dc4..0000000000 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cuh +++ /dev/null @@ -1,65 +0,0 @@ -#include -#include -#include - -/* -* Computes FP6-FP16 GEMM (C++ interface). -*/ -cudaError_t fp6_linear_kernel(cudaStream_t stream, - const uint4 *Weight, - const half *Scales, - const half *B, - half *C, - const size_t M_Global, - const size_t N_Global, - const size_t K_Global, - float *Reduction_Workspace, // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) - int Split_K); - -/* - * In-place weight prepacking (C++ interface). - */ -void weight_matrix_prepacking(int* packed_weights, int *FP6Weights, size_t M, size_t K); - -/* - * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs. 
- */ -void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale); - -#ifndef NO_PYTORCH -#include -#include - -namespace torchao { -/* -* Computes FP6-FP16 GEMM (PyTorch interface). -*/ -torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, - torch::Tensor _weights, - torch::Tensor _scales, - int splitK=1); - -/* - * Weight prepacking (Pytorch interface). - */ -torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor); - -/* - * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs. - * A useful tool to construct input matrices for the FP16 GEMM baseline. - * [Input] - * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. - * fp16_scale: half tensor of shape [OC]; // for row-wise quantization. - * [Output] - * fp16_tensor: half tensor of shape [OC, IC]. - */ -torch::Tensor weight_matrix_dequant_cpu(torch::Tensor fp6_tensor, torch::Tensor fp16_scale); - -TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::fp16act_fp6weight_linear", &fp6_linear_forward_cuda); - m.impl("torchao::fp6_weight_prepacking_cpu", &weight_matrix_prepacking_cpu); - m.impl("torchao::fp6_weight_dequant_cpu", &weight_matrix_dequant_cpu); -} - -} -#endif diff --git a/torchao/csrc/cuda/fp6_llm/configs.h b/torchao/csrc/cuda/fp6_test/configs.h similarity index 100% rename from torchao/csrc/cuda/fp6_llm/configs.h rename to torchao/csrc/cuda/fp6_test/configs.h diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_test/fp6_linear.cu similarity index 79% rename from torchao/csrc/cuda/fp6_llm/fp6_linear.cu rename to torchao/csrc/cuda/fp6_test/fp6_linear.cu index 01bb114c16..e9f051c2c9 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_test/fp6_linear.cu @@ -1,8 +1,5 @@ #include "kernel_matmul.cuh" #include "kernel_reduction.cuh" -#include "weight_prepacking.h" -#include "weight_dequant.h" -#include "weight_quant.h" #include #include @@ -103,12 +100,9 @@ cudaError_t fp6_linear_kernel(cudaStream_t stream, } - - - -#ifndef NO_PYTORCH #include #include +#include namespace torchao { /* @@ -129,7 +123,7 @@ After Equivalent transformation : trans(Out) = W * trans(In). Note that we torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, torch::Tensor _weights, torch::Tensor _scales, - int splitK=1) + int64_t splitK=1) { int num_in_feats = _in_feats.size(0); int num_in_channels = _in_feats.size(1); @@ -167,54 +161,8 @@ torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, return _out_feats; } - -/* - * Weight prepacking (Pytorch interface). - * [Input & Output] - * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. - * [Output] - * packed_tensor: int tensor of shape [OC, IC // 16 * 3]; - */ -torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor) -{ - size_t OC = fp6_tensor.size(0); - size_t IC = fp6_tensor.size(1); - assert (IC%3==0); - IC = IC*16/3; - assert( (OC%256==0) && (IC%64==0) ); - auto packed_tensor = torch::empty_like(fp6_tensor); - auto packed_tensor_ptr = reinterpret_cast(packed_tensor.data_ptr()); - auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); - weight_matrix_prepacking(packed_tensor_ptr, fp6_tensor_ptr, OC, IC); - return packed_tensor; +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::fp16act_fp6weight_linear", &fp6_linear_forward_cuda); } -/* - * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs. 
- * A useful tool to construct input matrices for the FP16 GEMM baseline. - * [Input] - * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. - * fp16_scale: half tensor of shape [OC]; // for row-wise quantization. - * [Output] - * fp16_tensor: half tensor of shape [OC, IC]. - */ -torch::Tensor weight_matrix_dequant_cpu(torch::Tensor fp6_tensor, torch::Tensor fp16_scale) -{ - int OC = fp6_tensor.size(0); - assert(fp6_tensor.size(1) % 3 == 0); - int IC = fp6_tensor.size(1) / 3 * 16; - assert(fp16_scale.size(0)==OC); - // - auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); - auto fp16_scale_ptr = reinterpret_cast(fp16_scale.data_ptr()); - // - auto options = torch::TensorOptions().dtype(fp16_scale.dtype()).device(fp16_scale.device()); - at::Tensor fp16_tensor = torch::empty({OC, IC}, options); - auto fp16_tensor_ptr = reinterpret_cast(fp16_tensor.data_ptr()); - // - DeQuantMatrix_FP6_To_FP16(fp16_tensor_ptr, (unsigned char*)fp6_tensor_ptr, OC, IC, fp16_scale_ptr); - // - return fp16_tensor; -} -} -#endif +} // namespace torchao diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_test/kernel_matmul.cuh similarity index 100% rename from torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh rename to torchao/csrc/cuda/fp6_test/kernel_matmul.cuh diff --git a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh b/torchao/csrc/cuda/fp6_test/kernel_reduction.cuh similarity index 100% rename from torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh rename to torchao/csrc/cuda/fp6_test/kernel_reduction.cuh diff --git a/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh b/torchao/csrc/cuda/fp6_test/ptx_cp.async.cuh similarity index 100% rename from torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh rename to torchao/csrc/cuda/fp6_test/ptx_cp.async.cuh diff --git a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh b/torchao/csrc/cuda/fp6_test/ptx_mma.cuh similarity index 100% rename from torchao/csrc/cuda/fp6_llm/ptx_mma.cuh rename to torchao/csrc/cuda/fp6_test/ptx_mma.cuh diff --git a/torchao/csrc/cuda/fp6_llm/utils_core.cuh b/torchao/csrc/cuda/fp6_test/utils_core.cuh similarity index 100% rename from torchao/csrc/cuda/fp6_llm/utils_core.cuh rename to torchao/csrc/cuda/fp6_test/utils_core.cuh diff --git a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh b/torchao/csrc/cuda/fp6_test/utils_gmem.cuh similarity index 100% rename from torchao/csrc/cuda/fp6_llm/utils_gmem.cuh rename to torchao/csrc/cuda/fp6_test/utils_gmem.cuh diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_test/utils_parallel_dequant.cuh similarity index 100% rename from torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh rename to torchao/csrc/cuda/fp6_test/utils_parallel_dequant.cuh diff --git a/torchao/csrc/fp6_test/fp6_llm.cpp b/torchao/csrc/fp6_test/fp6_llm.cpp index 6e848c785b..4d82c44517 100644 --- a/torchao/csrc/fp6_test/fp6_llm.cpp +++ b/torchao/csrc/fp6_test/fp6_llm.cpp @@ -4,5 +4,6 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.impl_abstract_pystub("torchao.ops"); + m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int64_t splitK) -> Tensor"); m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); } diff --git a/torchao/ops.py b/torchao/ops.py index 08253c5f59..5c5b676463 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -34,3 +34,7 @@ def _(fp6_weight): # num_to_keep = ctx.create_unbacked_symint() # return fp6_weight.new_empty(num_to_keep, dtype=torch.long) return torch.empty_like(fp6_weight) + + +def 
fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: int = 1) -> Tensor: + return torch.ops.torchao.fp16act_fp6weight_linear.default(_in_feats, _weights, _scales, splitK) From 031379a801df8af1d69a330c374aaa604e5c9670 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 7 May 2024 22:35:02 +0800 Subject: [PATCH 06/39] fix function def --- torchao/csrc/fp6_test/fp6_llm.cpp | 2 +- torchao/ops.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/torchao/csrc/fp6_test/fp6_llm.cpp b/torchao/csrc/fp6_test/fp6_llm.cpp index 4d82c44517..90c83f6bff 100644 --- a/torchao/csrc/fp6_test/fp6_llm.cpp +++ b/torchao/csrc/fp6_test/fp6_llm.cpp @@ -4,6 +4,6 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.impl_abstract_pystub("torchao.ops"); - m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int64_t splitK) -> Tensor"); + m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); } diff --git a/torchao/ops.py b/torchao/ops.py index 5c5b676463..6bfbe74003 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -38,3 +38,12 @@ def _(fp6_weight): def fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: int = 1) -> Tensor: return torch.ops.torchao.fp16act_fp6weight_linear.default(_in_feats, _weights, _scales, splitK) + + +@torch.library.impl_abstract("torchao::fp16act_fp6weight_linear") +def _(_in_feats, _weights, _scales, splitK = 1): + torch._check(fp6_weight.dim() == 2, lambda: f"weight should be a 2d tensor, got {dets.dim()}D") + # ctx = torch._custom_ops.get_ctx() + # num_to_keep = ctx.create_unbacked_symint() + # return fp6_weight.new_empty(num_to_keep, dtype=torch.long) + return torch.empty_like(_in_feats) From c436c43bbd708c9c6a27b432b176cabf9103ae5c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 8 May 2024 00:39:53 +0000 Subject: [PATCH 07/39] delete duplicate file --- torchao/csrc/cuda/fp6_llm/weight_prepacking.h | 171 ------------------ 1 file changed, 171 deletions(-) delete mode 100644 torchao/csrc/cuda/fp6_llm/weight_prepacking.h diff --git a/torchao/csrc/cuda/fp6_llm/weight_prepacking.h b/torchao/csrc/cuda/fp6_llm/weight_prepacking.h deleted file mode 100644 index b33bdb18a9..0000000000 --- a/torchao/csrc/cuda/fp6_llm/weight_prepacking.h +++ /dev/null @@ -1,171 +0,0 @@ -#include -#include -#include - -using namespace std; - -void Padding_8_FP6_To_8_Bytes(unsigned char Padded_FP6[], unsigned char* FP6_Array) // padding 0 to the lowerest bit location -{ - Padded_FP6[0] = FP6_Array[0] & 0xfc; - Padded_FP6[1] = (FP6_Array[0]<<6) | ((FP6_Array[1]>>2) & 0xfc); - Padded_FP6[2] = (FP6_Array[1]<<4) | ((FP6_Array[2]>>4) & 0xfc ); - Padded_FP6[3] = FP6_Array[2]<<2; - Padded_FP6[4] = FP6_Array[3] & 0xfc; - Padded_FP6[5] = (FP6_Array[3]<<6) | ((FP6_Array[4]>>2) & 0xfc); - Padded_FP6[6] = (FP6_Array[4]<<4) | ((FP6_Array[5]>>4) & 0xfc); - Padded_FP6[7] = FP6_Array[5]<<2; -} - -unsigned char Extract_2_Bits_From_4_PaddedFP6(unsigned char B1, unsigned char B2, unsigned char B3, unsigned char B4) -{ - unsigned char out; - out = (B1&0xc0) | ( (B2&0xc0) >> 2 ) | ( (B3&0xc0) >> 4 ) | ( (B4&0xc0) >> 6 ); - return out; -} - -unsigned char Extract_4_Bits_From_2_PaddedFP6(unsigned char B1, unsigned char B2) // The highest two bits are already extracted by Extract_2_Bits_From_4_PaddedFP6(); -{ - unsigned char out; - out = ( (B1<<2) & 0xf0 ) | ( (B2>>2) & 0x0f ); - return out; -} - -// dealing with 4 
1*8 blocks of FP6 -void Assign_32_FP6_To_4_Thread(vector Seg_2bit[], vector Seg_4bit[], unsigned char* PTR_1, unsigned char* PTR_2, unsigned char* PTR_3, unsigned char* PTR_4) -{ - unsigned char Padded_8_FP8[4][8]; - Padding_8_FP6_To_8_Bytes(Padded_8_FP8[0], PTR_1); - Padding_8_FP6_To_8_Bytes(Padded_8_FP8[1], PTR_2); - Padding_8_FP6_To_8_Bytes(Padded_8_FP8[2], PTR_3); - Padding_8_FP6_To_8_Bytes(Padded_8_FP8[3], PTR_4); - // - unsigned char Seg1_Byte1_T[4]; - unsigned char Seg1_Byte2_T[4]; - unsigned char Seg2_Byte1_T[4]; - unsigned char Seg2_Byte2_T[4]; - unsigned char Seg2_Byte3_T[4]; - unsigned char Seg2_Byte4_T[4]; - for(int t=0; t<4; t++) - { - Seg1_Byte1_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[0][0+t*2], Padded_8_FP8[0][1+t*2], Padded_8_FP8[1][0+t*2], Padded_8_FP8[1][1+t*2]); - Seg1_Byte2_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[2][0+t*2], Padded_8_FP8[2][1+t*2], Padded_8_FP8[3][0+t*2], Padded_8_FP8[3][1+t*2]); - Seg2_Byte1_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[0][0+t*2], Padded_8_FP8[0][1+t*2]); - Seg2_Byte2_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[1][0+t*2], Padded_8_FP8[1][1+t*2]); - Seg2_Byte3_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[2][0+t*2], Padded_8_FP8[2][1+t*2]); - Seg2_Byte4_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[3][0+t*2], Padded_8_FP8[3][1+t*2]); - } - // - for(int t=0; t<4; t++) - { - Seg_2bit[t].push_back(Seg1_Byte1_T[t]); - Seg_2bit[t].push_back(Seg1_Byte2_T[t]); - Seg_4bit[t].push_back(Seg2_Byte1_T[t]); - Seg_4bit[t].push_back(Seg2_Byte2_T[t]); - Seg_4bit[t].push_back(Seg2_Byte3_T[t]); - Seg_4bit[t].push_back(Seg2_Byte4_T[t]); - } - return; -} - -void BitInterleaving_2bit(unsigned char* PTR_4Bytes) -{ - unsigned int *PTR_UINT = reinterpret_cast(PTR_4Bytes); - unsigned int input = *PTR_UINT; - // - //int order_2bit[16] = {1,5,9,13,3,7,11,15,2,6,10,14,4,8,12,16}; // pre-defined order for bit-interleaving in FP6-LLM - int order_2bit[16] = {2,6,10,14,4,8,12,16,1,5,9,13,3,7,11,15}; // pre-defined order for bit-interleaving in FP6-LLM - unsigned int Frags_2bit[16]; // The highest 2 bits are used to store the extracted fragments. - for(int i=0; i<16; i++) - Frags_2bit[i] = ( input << 2*(order_2bit[i]-1) ) & 0xc0000000; - // - unsigned int output = 0x00000000; - for(int i=0; i<16; i++) - output |= ( Frags_2bit[i] >> (i*2) ); - // - *PTR_UINT = output; -} - -void BitInterleaving_4bit(unsigned char* PTR_4Bytes) -{ - unsigned int *PTR_UINT = reinterpret_cast(PTR_4Bytes); - unsigned int input = *PTR_UINT; - // - //int order_4bit[8] = {1,5,3,7,2,6,4,8}; // pre-defined order for bit-interleaving in FP6-LLM - int order_4bit[8] = {2,6,4,8,1,5,3,7}; // pre-defined order for bit-interleaving in FP6-LLM - unsigned int Frags_4bit[8]; // The highest4 bits are used to store the extracted fragments. - for(int i=0; i<8; i++) - Frags_4bit[i] = ( input << 4*(order_4bit[i]-1) ) & 0xf0000000; - // - unsigned int output = 0x00000000; - for(int i=0; i<8; i++) - output |= ( Frags_4bit[i] >> (i*4) ); - // - *PTR_UINT = output; -} - -/* - * Inputs: - * (1) unsigned char Weight_6bit [M*K*6/8] - * Outputs: - * (1) unsigned char Weight_2bit [M*K*2/8] - * (2) unsigned char Weight_4bit [M*K*4/8] - * - * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major. 
- * 8 FP6 = 6 Bytes - * 8 FP4 = 4 Bytes - * 8 FP2 = 2 Bytes - */ -void weight_matrix_prepacking(int* packed_weights, int *FP6Weights, size_t M, size_t K) -{ - assert(M % 64 == 0); - assert(K % 64 == 0); - // - unsigned char* Weight_6bit = reinterpret_cast(FP6Weights); - unsigned char* Weight_2bit = reinterpret_cast(packed_weights); - unsigned char* Weight_4bit = Weight_2bit + M*K*2/8; - // - vector A_Segment_2bit[32]; - vector A_Segment_4bit[32]; - // - size_t BytesPerRow = K*6/8; - // Pass-1: (1) 2+4 split; (2) assign weights to 32 threads. - for (size_t i = 0; i < M / 64; i++) // - { - for (size_t j = 0; j < K / 16; j++) - { - for(size_t k=0; k<64/16; k++) - { - size_t row = i*64 + k*16; - size_t col = j*16; - unsigned char* StartPTR_1 = Weight_6bit + row*BytesPerRow + col*6/8; - unsigned char* StartPTR_2 = StartPTR_1 + 8*BytesPerRow; - unsigned char* StartPTR_3 = StartPTR_1 + 8*6/8; - unsigned char* StartPTR_4 = StartPTR_2 + 8*6/8; - // Dealing with each 16*16 blocks then... - for(int l=0; l<8; l++) Assign_32_FP6_To_4_Thread(&A_Segment_2bit[l*4], &A_Segment_4bit[l*4], StartPTR_1+l*BytesPerRow, StartPTR_2+l*BytesPerRow, StartPTR_3+l*BytesPerRow, StartPTR_4+l*BytesPerRow); - } - } - } - // Verifying the length of 2_bit segments and 4_bit segments - size_t BytesPerThread_2bit = M*K*2/8/32; - size_t BytesPerThread_4bit = M*K*4/8/32; - for(int i=0; i<32; i++) - { - assert(A_Segment_2bit[i].size()==BytesPerThread_2bit); - assert(A_Segment_4bit[i].size()==BytesPerThread_4bit); - } - // Pass-2: Optimizing coleasced global memory access - for(size_t i=0; i Date: Wed, 8 May 2024 00:42:35 +0000 Subject: [PATCH 08/39] move weight quant file --- torchao/csrc/cuda/fp6_llm/weight_dequant.h | 40 ------------------- .../weight_quant.cpp} | 37 ++++++++++++++++- 2 files changed, 36 insertions(+), 41 deletions(-) delete mode 100644 torchao/csrc/cuda/fp6_llm/weight_dequant.h rename torchao/csrc/{cuda/fp6_llm/weight_quant.h => fp6_test/weight_quant.cpp} (68%) diff --git a/torchao/csrc/cuda/fp6_llm/weight_dequant.h b/torchao/csrc/cuda/fp6_llm/weight_dequant.h deleted file mode 100644 index 2451a55d78..0000000000 --- a/torchao/csrc/cuda/fp6_llm/weight_dequant.h +++ /dev/null @@ -1,40 +0,0 @@ -#include -#include -#include -#include -#include -#include - -void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale) { - assert(M%64==0); // Currently, M must be a multiple of 64. - assert(K%64==0); // Currently, K must be a multiple of 64. 
- size_t TotalSizeInByte = M*K*6/8; - // - half* OutPTR = A_16bit_h; - for(size_t i=0; i>2)&0x1f); - unsigned char B2 = (A_6bit_h[i*3+0]<<6) | ((A_6bit_h[i*3+1]>>2)&0xfc); - B2 = (B2&0x80) | ((B2>>2)&0x1f); - unsigned char B3 = (A_6bit_h[i*3+1]<<4) | ((A_6bit_h[i*3+2]>>4)&0xfc); - B3 = (B3&0x80) | ((B3>>2)&0x1f); - unsigned char B4 = A_6bit_h[i*3+2]<<2; - B4 = (B4&0x80) | ((B4>>2)&0x1f); - half FP1, FP2, FP3, FP4; - unsigned char *PTR1, *PTR2, *PTR3, *PTR4; - PTR1 = reinterpret_cast(&FP1); - PTR2 = reinterpret_cast(&FP2); - PTR3 = reinterpret_cast(&FP3); - PTR4 = reinterpret_cast(&FP4); - PTR1[0] = 0; PTR1[1] = B1; // small endian for X86 CPU - PTR2[0] = 0; PTR2[1] = B2; - PTR3[0] = 0; PTR3[1] = B3; - PTR4[0] = 0; PTR4[1] = B4; - OutPTR[0] = __float2half_rn ( __half2float(FP1) * 4096.0f * __half2float(scale[(4*i)/K]) ); - OutPTR[1] = __float2half_rn ( __half2float(FP2) * 4096.0f * __half2float(scale[(4*i)/K]) ); - OutPTR[2] = __float2half_rn ( __half2float(FP3) * 4096.0f * __half2float(scale[(4*i)/K]) ); - OutPTR[3] = __float2half_rn ( __half2float(FP4) * 4096.0f * __half2float(scale[(4*i)/K]) ); - // - OutPTR +=4; - } -} diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.h b/torchao/csrc/fp6_test/weight_quant.cpp similarity index 68% rename from torchao/csrc/cuda/fp6_llm/weight_quant.h rename to torchao/csrc/fp6_test/weight_quant.cpp index a434cd2d96..5b4240d478 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.h +++ b/torchao/csrc/fp6_test/weight_quant.cpp @@ -1,7 +1,8 @@ // Author: Zhen Zheng // To be used in the future as a tool to generating the FP6 matrix from the FP16 matrix. -#include +#include +#include /* * Function to pack 4 fake quantized FP16 value into continuously stored 4 FP6 values. @@ -103,3 +104,37 @@ void weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit, } } } + +void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale) { + assert(M%64==0); // Currently, M must be a multiple of 64. + assert(K%64==0); // Currently, K must be a multiple of 64. 
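+    // Each iteration below unpacks 3 bytes (4 FP6 values): the sign bit and the 5 exponent/
+    // mantissa bits of each FP6 are copied into the high byte of an FP16, and the result is
+    // scaled by 4096 = 2^(15-3) to bridge the gap between the FP16 exponent bias (15) and the
+    // FP6 exponent bias (3).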
+ size_t TotalSizeInByte = M*K*6/8; + // + half* OutPTR = A_16bit_h; + for(size_t i=0; i>2)&0x1f); + unsigned char B2 = (A_6bit_h[i*3+0]<<6) | ((A_6bit_h[i*3+1]>>2)&0xfc); + B2 = (B2&0x80) | ((B2>>2)&0x1f); + unsigned char B3 = (A_6bit_h[i*3+1]<<4) | ((A_6bit_h[i*3+2]>>4)&0xfc); + B3 = (B3&0x80) | ((B3>>2)&0x1f); + unsigned char B4 = A_6bit_h[i*3+2]<<2; + B4 = (B4&0x80) | ((B4>>2)&0x1f); + half FP1, FP2, FP3, FP4; + unsigned char *PTR1, *PTR2, *PTR3, *PTR4; + PTR1 = reinterpret_cast(&FP1); + PTR2 = reinterpret_cast(&FP2); + PTR3 = reinterpret_cast(&FP3); + PTR4 = reinterpret_cast(&FP4); + PTR1[0] = 0; PTR1[1] = B1; // small endian for X86 CPU + PTR2[0] = 0; PTR2[1] = B2; + PTR3[0] = 0; PTR3[1] = B3; + PTR4[0] = 0; PTR4[1] = B4; + OutPTR[0] = __float2half_rn ( __half2float(FP1) * 4096.0f * __half2float(scale[(4*i)/K]) ); + OutPTR[1] = __float2half_rn ( __half2float(FP2) * 4096.0f * __half2float(scale[(4*i)/K]) ); + OutPTR[2] = __float2half_rn ( __half2float(FP3) * 4096.0f * __half2float(scale[(4*i)/K]) ); + OutPTR[3] = __float2half_rn ( __half2float(FP4) * 4096.0f * __half2float(scale[(4*i)/K]) ); + // + OutPTR +=4; + } +} From 9180feff4a4a98467bba4fff84984c7955f93bce Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 8 May 2024 00:43:54 +0000 Subject: [PATCH 09/39] rename --- setup.py | 6 ++---- torchao/csrc/cuda/{fp6_test => fp6_llm}/configs.h | 0 torchao/csrc/cuda/{fp6_test => fp6_llm}/fp6_linear.cu | 0 torchao/csrc/cuda/{fp6_test => fp6_llm}/kernel_matmul.cuh | 0 .../csrc/cuda/{fp6_test => fp6_llm}/kernel_reduction.cuh | 0 torchao/csrc/cuda/{fp6_test => fp6_llm}/ptx_cp.async.cuh | 0 torchao/csrc/cuda/{fp6_test => fp6_llm}/ptx_mma.cuh | 0 torchao/csrc/cuda/{fp6_test => fp6_llm}/utils_core.cuh | 0 torchao/csrc/cuda/{fp6_test => fp6_llm}/utils_gmem.cuh | 0 .../cuda/{fp6_test => fp6_llm}/utils_parallel_dequant.cuh | 0 torchao/csrc/{fp6_test => fp6_llm}/fp6_llm.cpp | 0 torchao/csrc/{fp6_test => fp6_llm}/weight_prepacking.cpp | 0 torchao/csrc/{fp6_test => fp6_llm}/weight_quant.cpp | 0 13 files changed, 2 insertions(+), 4 deletions(-) rename torchao/csrc/cuda/{fp6_test => fp6_llm}/configs.h (100%) rename torchao/csrc/cuda/{fp6_test => fp6_llm}/fp6_linear.cu (100%) rename torchao/csrc/cuda/{fp6_test => fp6_llm}/kernel_matmul.cuh (100%) rename torchao/csrc/cuda/{fp6_test => fp6_llm}/kernel_reduction.cuh (100%) rename torchao/csrc/cuda/{fp6_test => fp6_llm}/ptx_cp.async.cuh (100%) rename torchao/csrc/cuda/{fp6_test => fp6_llm}/ptx_mma.cuh (100%) rename torchao/csrc/cuda/{fp6_test => fp6_llm}/utils_core.cuh (100%) rename torchao/csrc/cuda/{fp6_test => fp6_llm}/utils_gmem.cuh (100%) rename torchao/csrc/cuda/{fp6_test => fp6_llm}/utils_parallel_dequant.cuh (100%) rename torchao/csrc/{fp6_test => fp6_llm}/fp6_llm.cpp (100%) rename torchao/csrc/{fp6_test => fp6_llm}/weight_prepacking.cpp (100%) rename torchao/csrc/{fp6_test => fp6_llm}/weight_quant.cpp (100%) diff --git a/setup.py b/setup.py index cf0badfb67..49f9b5b9d0 100644 --- a/setup.py +++ b/setup.py @@ -63,12 +63,10 @@ def get_extensions(): this_dir = os.path.dirname(os.path.curdir) extensions_dir = os.path.join(this_dir, "torchao", "csrc") - sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp"))) - sources += list(glob.glob(os.path.join(extensions_dir, "fp6_test/*.cpp"))) + sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"))) extensions_cuda_dir = os.path.join(extensions_dir, "cuda") - cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu"))) - cuda_sources += 
list(glob.glob(os.path.join(extensions_cuda_dir, "fp6_test/*.cu"))) + cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"))) if use_cuda: sources += cuda_sources diff --git a/torchao/csrc/cuda/fp6_test/configs.h b/torchao/csrc/cuda/fp6_llm/configs.h similarity index 100% rename from torchao/csrc/cuda/fp6_test/configs.h rename to torchao/csrc/cuda/fp6_llm/configs.h diff --git a/torchao/csrc/cuda/fp6_test/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu similarity index 100% rename from torchao/csrc/cuda/fp6_test/fp6_linear.cu rename to torchao/csrc/cuda/fp6_llm/fp6_linear.cu diff --git a/torchao/csrc/cuda/fp6_test/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh similarity index 100% rename from torchao/csrc/cuda/fp6_test/kernel_matmul.cuh rename to torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh diff --git a/torchao/csrc/cuda/fp6_test/kernel_reduction.cuh b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh similarity index 100% rename from torchao/csrc/cuda/fp6_test/kernel_reduction.cuh rename to torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh diff --git a/torchao/csrc/cuda/fp6_test/ptx_cp.async.cuh b/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh similarity index 100% rename from torchao/csrc/cuda/fp6_test/ptx_cp.async.cuh rename to torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh diff --git a/torchao/csrc/cuda/fp6_test/ptx_mma.cuh b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh similarity index 100% rename from torchao/csrc/cuda/fp6_test/ptx_mma.cuh rename to torchao/csrc/cuda/fp6_llm/ptx_mma.cuh diff --git a/torchao/csrc/cuda/fp6_test/utils_core.cuh b/torchao/csrc/cuda/fp6_llm/utils_core.cuh similarity index 100% rename from torchao/csrc/cuda/fp6_test/utils_core.cuh rename to torchao/csrc/cuda/fp6_llm/utils_core.cuh diff --git a/torchao/csrc/cuda/fp6_test/utils_gmem.cuh b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh similarity index 100% rename from torchao/csrc/cuda/fp6_test/utils_gmem.cuh rename to torchao/csrc/cuda/fp6_llm/utils_gmem.cuh diff --git a/torchao/csrc/cuda/fp6_test/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh similarity index 100% rename from torchao/csrc/cuda/fp6_test/utils_parallel_dequant.cuh rename to torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh diff --git a/torchao/csrc/fp6_test/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp similarity index 100% rename from torchao/csrc/fp6_test/fp6_llm.cpp rename to torchao/csrc/fp6_llm/fp6_llm.cpp diff --git a/torchao/csrc/fp6_test/weight_prepacking.cpp b/torchao/csrc/fp6_llm/weight_prepacking.cpp similarity index 100% rename from torchao/csrc/fp6_test/weight_prepacking.cpp rename to torchao/csrc/fp6_llm/weight_prepacking.cpp diff --git a/torchao/csrc/fp6_test/weight_quant.cpp b/torchao/csrc/fp6_llm/weight_quant.cpp similarity index 100% rename from torchao/csrc/fp6_test/weight_quant.cpp rename to torchao/csrc/fp6_llm/weight_quant.cpp From 1b2442422fdacf4f1ef4dc42d1f147dc096b9dfa Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 8 May 2024 00:54:43 +0000 Subject: [PATCH 10/39] add pytorch interface for fp6 weight dequant --- setup.py | 4 +-- torchao/csrc/fp6_llm/fp6_llm.cpp | 1 + torchao/csrc/fp6_llm/weight_quant.cpp | 43 ++++++++++++++++++++++++++- torchao/ops.py | 9 ++---- 4 files changed, 47 insertions(+), 10 deletions(-) diff --git a/setup.py b/setup.py index 49f9b5b9d0..ff180b6337 100644 --- a/setup.py +++ b/setup.py @@ -63,10 +63,10 @@ def get_extensions(): this_dir = os.path.dirname(os.path.curdir) extensions_dir = os.path.join(this_dir, "torchao", "csrc") - 
sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"))) + sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True)) extensions_cuda_dir = os.path.join(extensions_dir, "cuda") - cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"))) + cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)) if use_cuda: sources += cuda_sources diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index 90c83f6bff..057878b534 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -6,4 +6,5 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.impl_abstract_pystub("torchao.ops"); m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); + m.def("fp6_weight_dequant(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); } diff --git a/torchao/csrc/fp6_llm/weight_quant.cpp b/torchao/csrc/fp6_llm/weight_quant.cpp index 5b4240d478..70cfaa5b60 100644 --- a/torchao/csrc/fp6_llm/weight_quant.cpp +++ b/torchao/csrc/fp6_llm/weight_quant.cpp @@ -1,8 +1,9 @@ // Author: Zhen Zheng // To be used in the future as a tool to generating the FP6 matrix from the FP16 matrix. -#include #include +#include +#include /* * Function to pack 4 fake quantized FP16 value into continuously stored 4 FP6 values. @@ -138,3 +139,43 @@ void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t OutPTR +=4; } } + + +#include +#include + +namespace torchao { + +/* + * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs. + * A useful tool to construct input matrices for the FP16 GEMM baseline. + * [Input] + * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. + * fp16_scale: half tensor of shape [OC]; // for row-wise quantization. + * [Output] + * fp16_tensor: half tensor of shape [OC, IC]. 
+ */ +at::Tensor weight_matrix_dequant_cpu(at::Tensor fp6_tensor, at::Tensor fp16_scale) +{ + int OC = fp6_tensor.size(0); + assert(fp6_tensor.size(1) % 3 == 0); + int IC = fp6_tensor.size(1) / 3 * 16; + assert(fp16_scale.size(0)==OC); + // + auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); + auto fp16_scale_ptr = reinterpret_cast(fp16_scale.data_ptr()); + // + auto options = at::TensorOptions().dtype(fp16_scale.dtype()).device(fp16_scale.device()); + at::Tensor fp16_tensor = at::empty({OC, IC}, options); + auto fp16_tensor_ptr = reinterpret_cast(fp16_tensor.data_ptr()); + // + DeQuantMatrix_FP6_To_FP16(fp16_tensor_ptr, (unsigned char*)fp6_tensor_ptr, OC, IC, fp16_scale_ptr); + // + return fp16_tensor; +} + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu); +} + +} diff --git a/torchao/ops.py b/torchao/ops.py index 6bfbe74003..7ce84c2ad7 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -40,10 +40,5 @@ def fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tenso return torch.ops.torchao.fp16act_fp6weight_linear.default(_in_feats, _weights, _scales, splitK) -@torch.library.impl_abstract("torchao::fp16act_fp6weight_linear") -def _(_in_feats, _weights, _scales, splitK = 1): - torch._check(fp6_weight.dim() == 2, lambda: f"weight should be a 2d tensor, got {dets.dim()}D") - # ctx = torch._custom_ops.get_ctx() - # num_to_keep = ctx.create_unbacked_symint() - # return fp6_weight.new_empty(num_to_keep, dtype=torch.long) - return torch.empty_like(_in_feats) +def fp6_weight_dequant(fp6_tensor: Tensor, fp16_scale: Tensor) -> Tensor: + return torch.ops.torchao.fp6_weight_dequant.default(fp6_tensor, fp16_scale) From 2671c9cd3cbdbfa5e4675506a0a5bb867942c282 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 8 May 2024 01:26:30 +0000 Subject: [PATCH 11/39] add fake_fp6 to fp6 --- torchao/csrc/fp6_llm/fp6_llm.cpp | 1 + torchao/csrc/fp6_llm/weight_quant.cpp | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index 057878b534..9e1bcaaeb2 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -6,5 +6,6 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.impl_abstract_pystub("torchao.ops"); m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); + m.def("fake_fp6_to_fp6(Tensor fake_fp6_tensor) -> Tensor"); m.def("fp6_weight_dequant(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); } diff --git a/torchao/csrc/fp6_llm/weight_quant.cpp b/torchao/csrc/fp6_llm/weight_quant.cpp index 70cfaa5b60..d6ef19e19e 100644 --- a/torchao/csrc/fp6_llm/weight_quant.cpp +++ b/torchao/csrc/fp6_llm/weight_quant.cpp @@ -141,11 +141,34 @@ void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t } +#include #include #include namespace torchao { +// https://github.com/microsoft/DeepSpeed/blob/0fc19b6a320cf8aa0a5f6c2b1fa310bae9a70d94/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels.cpp#L194 +at::Tensor fake_fp6_to_fp6_cpu(at::Tensor fake_fp6_tensor) +{ + TORCH_CHECK(fake_fp6_tensor.dim() == 2, "weight must be 2-dimensional"); + TORCH_CHECK(fake_fp6_tensor.scalar_type() == torch::kFloat16, "weight must be FP16"); + TORCH_CHECK(fake_fp6_tensor.is_contiguous(), "weight must be contiguous"); + TORCH_CHECK(fake_fp6_tensor.device().type() == torch::kCPU, "weight must 
be on CPU"); + auto M = fake_fp6_tensor.size(0); + auto K = fake_fp6_tensor.size(1); + TORCH_CHECK(K % 4 == 0, "K must be multiple of 4"); + + // Pack weight from FP16 to FP6. + auto options = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto packed_fp6_tensor = at::empty({M, K * 6 / 8}, options); + uint8_t* packed_fp6_ptr = packed_fp6_tensor.data_ptr(); + + uint16_t* fake_fp6_ptr = reinterpret_cast(fake_fp6_tensor.data_ptr()); + weight_prepacking_fp16_to_fp6(fake_fp6_ptr, packed_fp6_ptr, M, K); + + return packed_fp6_tensor; +} + /* * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs. * A useful tool to construct input matrices for the FP16 GEMM baseline. @@ -175,6 +198,7 @@ at::Tensor weight_matrix_dequant_cpu(at::Tensor fp6_tensor, at::Tensor fp16_scal } TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::fake_fp6_to_fp6", &fake_fp6_to_fp6_cpu); m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu); } From e61be51a80ff1f3b3a9b0d11b44269edee07d42d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 8 May 2024 20:31:56 +0800 Subject: [PATCH 12/39] move weight_quant to csrc/cuda due to cuda_fp16.h dependency --- .../{fp6_llm/weight_quant.cpp => cuda/fp6_llm/weight_quant.cu} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename torchao/csrc/{fp6_llm/weight_quant.cpp => cuda/fp6_llm/weight_quant.cu} (100%) diff --git a/torchao/csrc/fp6_llm/weight_quant.cpp b/torchao/csrc/cuda/fp6_llm/weight_quant.cu similarity index 100% rename from torchao/csrc/fp6_llm/weight_quant.cpp rename to torchao/csrc/cuda/fp6_llm/weight_quant.cu From 21acfd14ffd625ee942f7054932ef5eda020e4d1 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 8 May 2024 21:44:33 +0800 Subject: [PATCH 13/39] add fake_fp6_to_fp6 test --- test/test_ops.py | 19 +++++++++++++++++++ torchao/ops.py | 18 ++++++++++++++---- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 6f46a46bef..753e0980d5 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -54,6 +54,25 @@ def test_prepack_fp6_weight(self): test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils) + def test_fake_fp6_to_fp6(self): + OC = 256 + IC = 256 + + # in this fp6, we use 3 bits for exponent and 2 bits for mantissa + # also, we don't have nan/inf + fp6_absmax = 28.0 # 2 ** (0b111 - 0b011) * (1 + 0.5 + 0.25), where E=111, M=11 + fp6_absmin = 0.0625 # 2 ** (-0b010) * 0.25, where E=000, M=01 (subnormal number) + fake_fp6_weight = torch.randn((OC, IC), dtype=torch.float16) + fake_fp6_weight.clip_(-fp6_absmax, fp6_absmax) + fake_fp6_weight[fake_fp6_weight.abs() < fp6_absmin] = 0 + + # smoke test + torchao.ops.fake_fp6_to_fp6(fake_fp6_weight) + + # comprehensive testing + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] + opcheck(torch.ops.torchao.fake_fp6_to_fp6, (fake_fp6_weight,), test_utils=test_utils) + if __name__ == "__main__": unittest.main() diff --git a/torchao/ops.py b/torchao/ops.py index 7ce84c2ad7..5ee714817f 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -29,13 +29,23 @@ def prepack_fp6_weight(fp6_weight: Tensor) -> Tensor: @torch.library.impl_abstract("torchao::prepack_fp6_weight") def _(fp6_weight): - torch._check(fp6_weight.dim() == 2, lambda: f"weight should be a 2d tensor, got {dets.dim()}D") - # ctx = torch._custom_ops.get_ctx() - # num_to_keep = 
ctx.create_unbacked_symint() - # return fp6_weight.new_empty(num_to_keep, dtype=torch.long) + torch._check(fp6_weight.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp6_weight.dim()}D") return torch.empty_like(fp6_weight) +def fake_fp6_to_fp6(fake_fp6_tensor: Tensor) -> Tensor: + return torch.ops.torchao.fake_fp6_to_fp6.default(fake_fp6_tensor) + + +@torch.library.impl_abstract("torchao::fake_fp6_to_fp6") +def _(fake_fp6_tensor): + torch._check(fake_fp6_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fake_fp6_tensor.dim()}D") + torch._check(fake_fp6_tensor.dtype is torch.float16, lambda: f"weight must be FP16, got {fake_fp6_tensor.dtype}") + M, K = fake_fp6_tensor.shape + torch._check(K % 4 == 0, lambda: f"second dimension must be a multiple of 4, got {K}") + return torch.empty((M, K * 6 // 8), dtype=torch.uint8, device=fake_fp6_tensor.device) + + def fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: int = 1) -> Tensor: return torch.ops.torchao.fp16act_fp6weight_linear.default(_in_feats, _weights, _scales, splitK) From 67fd6f811a5e567d142b8bf369cb5e7759fa4f71 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 8 May 2024 21:59:23 +0800 Subject: [PATCH 14/39] add test for fp16act_fp6weight_linear --- test/test_ops.py | 27 +++++++++++++++++++++++++++ torchao/ops.py | 17 +++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index 753e0980d5..55dc54bf30 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -45,6 +45,8 @@ def test_nms(self): def test_prepack_fp6_weight(self): OC = 256 IC = 256 + + # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int) # smoke test @@ -54,6 +56,7 @@ def test_prepack_fp6_weight(self): test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils) + # this may fail in CPU, since there is no compiled function without CUDA def test_fake_fp6_to_fp6(self): OC = 256 IC = 256 @@ -73,6 +76,30 @@ def test_fake_fp6_to_fp6(self): test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] opcheck(torch.ops.torchao.fake_fp6_to_fp6, (fake_fp6_weight,), test_utils=test_utils) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_fp16act_fp6weight_linear(self): + BS = 2 + OC = 256 + IC = 256 + splitK = 1 + + # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. 
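+        # (torch.rand(...) + 0.5 below keeps the scales and activations in [0.5, 1.5), away from zero.)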
+ fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int) + fp16_scale = torch.rand(OC).to(torch.float16) + 0.5 + fp16_activation = torch.rand(BS, IC).to(torch.float16) + 0.5 + + fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight) + act_cuda = fp16_activation.cuda() + weight_cuda = fp6_weight_packed.cuda() + scale_cuda = fp16_scale.cuda() + + # smoke test + torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK) + + # comprehensive testing + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] + opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils) + if __name__ == "__main__": unittest.main() diff --git a/torchao/ops.py b/torchao/ops.py index 5ee714817f..37b99fd664 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -50,5 +50,22 @@ def fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tenso return torch.ops.torchao.fp16act_fp6weight_linear.default(_in_feats, _weights, _scales, splitK) +@torch.library.impl_abstract("torchao::fp16act_fp6weight_linear") +def _(_in_feats, _weights, _scales, splitK = 1): + torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D") + torch._check(_in_feats.dtype is torch.float16, lambda: f"weight must be FP16, got {_in_feats.dtype}") + torch._check(_weights.dim() == 2, lambda: f"weight should be a 2d tensor, got {_weights.dim()}D") + torch._check(_weights.dtype is torch.int32, lambda: f"weight must be INT32, got {_weights.dtype}") + torch._check(_scales.dim() == 1, lambda: f"scale should be a 2d tensor, got {_scales.dim()}D") + torch._check(_scales.dtype is torch.float16, lambda: f"weight must be INT32, got {_scales.dtype}") + + BS, IC = _in_feats.shape + OC, _ = _weights.shape + torch._check(IC / 16 * 3 == _weights.shape[1], lambda: "Dimensions mismatched") + torch._check(OC == _scales.shape[0], lambda: "Dimensions mismatched") + + return _in_feats.new_empty((BS, OC)) + + def fp6_weight_dequant(fp6_tensor: Tensor, fp16_scale: Tensor) -> Tensor: return torch.ops.torchao.fp6_weight_dequant.default(fp6_tensor, fp16_scale) From 084b7e4708040b325bae4458c6da8815342449dd Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 8 May 2024 22:03:55 +0800 Subject: [PATCH 15/39] add test for fp6_weight_dequant --- test/test_ops.py | 15 +++++++++++++++ torchao/ops.py | 15 ++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index 55dc54bf30..b047a38d1c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -100,6 +100,21 @@ def test_fp16act_fp6weight_linear(self): test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils) + def test_fp6_weight_dequant(self): + OC = 256 + IC = 256 + + # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. 
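+        # fp6_weight_dequant maps each packed row of IC // 16 * 3 int32 words back to IC FP16
+        # values, applying one scale per output channel, so the dequantized matrix is (OC, IC).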
+ fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int) + fp16_scale = torch.rand(OC).to(torch.float16) + 0.5 + + # smoke test + torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale) + + # comprehensive testing + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] + opcheck(torch.ops.torchao.fp6_weight_dequant, (fp6_weight, fp16_scale), test_utils=test_utils) + if __name__ == "__main__": unittest.main() diff --git a/torchao/ops.py b/torchao/ops.py index 37b99fd664..43a0b2c7ba 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -57,7 +57,7 @@ def _(_in_feats, _weights, _scales, splitK = 1): torch._check(_weights.dim() == 2, lambda: f"weight should be a 2d tensor, got {_weights.dim()}D") torch._check(_weights.dtype is torch.int32, lambda: f"weight must be INT32, got {_weights.dtype}") torch._check(_scales.dim() == 1, lambda: f"scale should be a 2d tensor, got {_scales.dim()}D") - torch._check(_scales.dtype is torch.float16, lambda: f"weight must be INT32, got {_scales.dtype}") + torch._check(_scales.dtype is torch.float16, lambda: f"scale must be FP16, got {_scales.dtype}") BS, IC = _in_feats.shape OC, _ = _weights.shape @@ -69,3 +69,16 @@ def _(_in_feats, _weights, _scales, splitK = 1): def fp6_weight_dequant(fp6_tensor: Tensor, fp16_scale: Tensor) -> Tensor: return torch.ops.torchao.fp6_weight_dequant.default(fp6_tensor, fp16_scale) + + +@torch.library.impl_abstract("torchao::fp6_weight_dequant") +def _(fp6_tensor, fp16_scale): + torch._check(fp6_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp6_tensor.dim()}D") + torch._check(fp6_tensor.dtype is torch.int32, lambda: f"weight must be INT32, got {fp6_tensor.dtype}") + torch._check(fp16_scale.dim() == 1, lambda: f"scale should be a 2d tensor, got {fp16_scale.dim()}D") + torch._check(fp16_scale.dtype is torch.float16, lambda: f"scale must be FP16, got {fp16_scale.dtype}") + + OC, _IC = fp6_tensor.shape + torch._check(OC == fp16_scale.shape[0], lambda: "Dimensions mismatched") + + return fp16_scale.new_empty((OC, _IC * 16 // 3)) From 6d2fc3ef550782f4fc6d7c1e6064160085c50bf3 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 9 May 2024 00:47:15 +0800 Subject: [PATCH 16/39] Fp6WeightOnlyQuantizedLinearWeight (not working yet) --- torchao/__init__.py | 11 +++-- torchao/quantization/subclass.py | 84 ++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 5 deletions(-) diff --git a/torchao/__init__.py b/torchao/__init__.py index 340bfe3013..c5c711dd84 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -1,14 +1,15 @@ +import torch +from torch.testing._internal.common_utils import IS_FBCODE +if not IS_FBCODE: + from . import _C + from . import ops + from torchao.quantization import ( apply_weight_only_int8_quant, apply_dynamic_quant, autoquant, ) from . import dtypes -import torch -from torch.testing._internal.common_utils import IS_FBCODE -if not IS_FBCODE: - from . import _C - from . import ops __all__ = [ "dtypes", diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 6128720d4d..4b355625de 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -4,11 +4,13 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
+import logging import warnings import torch from torch.utils._python_dispatch import return_and_correct_aliasing +import torchao.ops from .quant_primitives import ( dequantize_per_channel, dynamically_quantize_per_channel, @@ -865,3 +867,85 @@ def __torch_dispatch__(cls, func, types, args, kwargs): kwargs, args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), ) + + +class Fp6WeightOnlyQuantizedLinearWeight(torch.Tensor): + + @staticmethod + def __new__( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + shape: torch.Size, + dtype=None, + *args, + **kwargs + ): + kwargs["device"] = int_data.device + # kwargs["layout"] = ( + # kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + # ) + if dtype is None: + dtype = scale.dtype + kwargs["dtype"] = dtype + assert not kwargs.get("requires_grad", False) + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + int_data: torch.Tensor, + scale: torch.Tensor, + shape: torch.Size, + dtype=None, + *args, + **kwargs + ): + self.int_data = int_data + self.scale = scale + + def dequantize(self, output_dtype=torch.float16): + return torchao.ops.fp6_weight_dequant(self.int_data.cpu(), self.scale.cpu()).to(output_dtype).to(self.int_data.device) + + # https://github.com/microsoft/DeepSpeed/blob/0b224edcf7d83713b95ad6b989694a8bdf01809e/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py + @classmethod + def from_float(cls, input_float): + try: + from qtorch.quant import float_quantize + except: + logging.exception("qtorch is not available. Please install qtorch to use FP6 weight") + raise + + device = input_float.device + input_float = input_float.cpu().float() + + num_bits = 6 + exp_bits = 3 + q_range = 28 + + man_bits = num_bits - exp_bits - 1 + + max_input = input_float.abs().amax(dim=1) # symmetric quantization + scales = max_input / q_range # q_range + 1 + scales[scales == 0] = 1 # avoid zero scales + scaled_input = input_float / scales.view(-1, 1) + scales = scales.half() + + quantized_fake_fp6 = float_quantize(scaled_input, exp_bits, man_bits, rounding="nearest").half() + fp6_weight = torchao.ops.fake_fp6_to_fp6(quantized_fake_fp6) + fp6_weight = torchao.ops.prepack_fp6_weight(fp6_weight.view(torch.int32)) + + return cls( + fp6_weight.to(device), + scales.to(device), + input_float.shape, + dtype=torch.float16, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + print(func) + if func is aten.linear.default: + return torchao.ops.fp16act_fp6weight_linear(args[1], args[0].int_data, args[0].scale) + + raise NotImplementedError From 68f241503cc2ab3ce4d3a96b09ad9c9957825eca Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 9 May 2024 00:50:47 +0800 Subject: [PATCH 17/39] skip some tests, since the functions are not built w/o CUDA --- test/test_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index b047a38d1c..a3a8a6b126 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -56,7 +56,7 @@ def test_prepack_fp6_weight(self): test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils) - # this may fail in CPU, since there is no compiled function without CUDA + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_fake_fp6_to_fp6(self): OC = 
256 IC = 256 @@ -100,6 +100,7 @@ def test_fp16act_fp6weight_linear(self): test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_fp6_weight_dequant(self): OC = 256 IC = 256 From 5989599071c3f56c98f5e052644b5803470f9fea Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 9 May 2024 01:17:10 +0000 Subject: [PATCH 18/39] add the original test --- fp6_test.py | 98 +++++++++++++++++++++++++++++++++++++++++++++++++ fp6_test_run.sh | 47 ++++++++++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 fp6_test.py create mode 100644 fp6_test_run.sh diff --git a/fp6_test.py b/fp6_test.py new file mode 100644 index 0000000000..a83769e7c2 --- /dev/null +++ b/fp6_test.py @@ -0,0 +1,98 @@ +# from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py + +import argparse +import torch +import torchao + +WARMUP = 10 +REPEAT = 1000 + +parser = argparse.ArgumentParser(description='The shape of the MatMul: (M, K)*(K, N)->(M, N).') +parser.add_argument('--OC', type=int, required=False, default=4096, help='number of rows of the weight matrix.') +parser.add_argument('--IC', type=int, required=False, default=4096, help='number of columns of the weight matrix.') +parser.add_argument('--BS', type=int, required=False, default=32, help='inference batch size.') +parser.add_argument('--splitK', type=int, required=False, default=1, help='Split-K parameters allow users to split the GEMM computation along the K dimension so that more CTAs will be created with a better SM utilization.') +args = parser.parse_args() + +assert(args.OC%256==0) +assert(args.IC%64==0) + +print("#"*64) +print(args) + +fp6_weight = torch.randint(4294967295, (args.OC,args.IC//16*3)).to(torch.int) # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. 
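+# One FP16 scale per output channel (row of the weight matrix); activations stay in plain FP16.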
+fp16_scale = torch.rand(args.OC).to(torch.half)+0.5 +fp16_activation = torch.rand(args.BS, args.IC).to(torch.half)+0.5 + +start_event = torch.cuda.Event(enable_timing=True) +end_event = torch.cuda.Event(enable_timing=True) + +# fp6-fp16 GEMM (fp6-llm) +#################################################################################################################################### +torch.cuda.synchronize() +fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight) +act_cuda = fp16_activation.cuda() +weight_cuda = fp6_weight_packed.cuda() +scale_cuda = fp16_scale.cuda() +for i in range(WARMUP): + results_fp6_llm = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, args.splitK); +start_event.record() +for i in range(REPEAT): + results_fp6_llm = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, args.splitK); +end_event.record() +torch.cuda.synchronize() +fp6_llm_time_ms = start_event.elapsed_time(end_event)/REPEAT +fp6_llm_tflops = args.OC*args.IC*args.BS*2/fp6_llm_time_ms/1e9 +#################################################################################################################################### + +# baseline fp16 GEMM (cuBLAS) +#################################################################################################################################### +torch.cuda.synchronize() +fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale) +cuBLAS_MatMul = torch.nn.Linear(args.IC, args.OC, False) +results_cublas = None +with torch.no_grad(): + cuBLAS_MatMul.weight = torch.nn.Parameter(fp16_weight.clone().cuda()) + act_cuda = fp16_activation.cuda() + for i in range(WARMUP): + results_cublas = cuBLAS_MatMul(act_cuda) + start_event.record() + for i in range(REPEAT): + results_cublas = cuBLAS_MatMul(act_cuda) + end_event.record() +torch.cuda.synchronize() +cublas_time_ms = start_event.elapsed_time(end_event)/REPEAT +cublas_tflops = args.OC*args.IC*args.BS*2/cublas_time_ms/1e9 +#################################################################################################################################### + +# Performance +print( 'cuBLAS time: {:.2f} ms \t\t cuBLAS TFLOPs: {:.1f}'.format(cublas_time_ms, cublas_tflops) ) +print( 'fp6-llm time: {:.2f} ms \t\t fp6-llm TFLOPs: {:.1f}'.format(fp6_llm_time_ms, fp6_llm_tflops) ) +print( 'speedup: {:.2f}'.format(cublas_time_ms/fp6_llm_time_ms) ) + +# Correctness +error = results_cublas.cpu() - results_fp6_llm.cpu() +ground_truth = results_cublas.cpu() +mean_error = torch.mean(abs(error)) +mean_ground_truth = torch.mean(abs(ground_truth)) +relative_error = mean_error.item()/mean_ground_truth.item() +print( "relative error: {:.6f}".format(relative_error) ) + + + +# 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 +# "00110000110000110000110000110000" "11000011000011000011000011000011" "00001100001100001100001100001100" +# 818089008 3272356035 204522252 +#fp6_weight = torch.zeros(args.OC, args.IC//16*3).to(torch.int64) +#for i in range(args.OC): +# for j in range(args.IC//16): +# fp6_weight[i][j*3+0] = 818089008 +# fp6_weight[i][j*3+1] = 3272356035 +# fp6_weight[i][j*3+2] = 204522252 +#fp6_weight = fp6_weight.to(torch.int) + +# Ensuring that the absolute error or relative error of each matrix element is smaller than 1e-3. 
+#Error = [1e-2] +#for err in Error: +# AllClose = torch.allclose(results_fp6_llm.cpu(), results_cublas.cpu(), rtol=err, atol=err, equal_nan=True) +# print("torch.allclose\t (relative/absolute_error<" + str(err) + ") \t-> " + str(AllClose)) diff --git a/fp6_test_run.sh b/fp6_test_run.sh new file mode 100644 index 0000000000..a2462d2305 --- /dev/null +++ b/fp6_test_run.sh @@ -0,0 +1,47 @@ +#! /bin/bash + +# [Batch sizes to test] +# If you want to test the performance of FP6-LLM for larger inference batch sizes, +# which typically happens during prompt processing, +# please revise this file by simply "commenting" and "uncommenting". + +N=(1 2 4 8 16 32 64) +SplitK=(5 6 7 6) + +# BS <=64 +#N=(1 2 4 8 16 32 64) +#SplitK=(5 6 7 6) + +# BS = 128 +#N=(128) +#SplitK=(5 3 3 3) + +# BS = 256 +#N=(256) +#SplitK=(4 3 2 3) + +# BS = 512 +#N=(512) +#SplitK=(2 5 2 4) + +# BS = 1024 +#N=(1024) +#SplitK=(1 2 1 2) + +# BS >= 2048 +# N = (2048, 4096, 8192, 16384) +#SplitK=(1 1 1 1) + +# Benchmarking the specific Matrix Shape from llama2-70b +M=(10240 8192 57344 8192) +K=(8192 8192 8192 28672) + +#mkdir -p Profiling +for ((i=0;i<${#M[@]};i++)) +do + for BS in ${N[@]} + do + #ncu -f -o Profiling/M${M[i]}K${K[i]}N${BS} --set full \ + python fp6_test.py --OC=${M[i]} --IC=${K[i]} --BS=${BS} --splitK=${SplitK[i]} + done +done From 92dfde4a2b8fb250beb1f160fc542927fcd43746 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 9 May 2024 01:52:36 +0000 Subject: [PATCH 19/39] implement transpose and clone so that F.linear will work --- torchao/quantization/subclass.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 4b355625de..1e5e182c22 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -878,6 +878,7 @@ def __new__( scale: torch.Tensor, shape: torch.Size, dtype=None, + transposed: bool = False, *args, **kwargs ): @@ -898,11 +899,13 @@ def __init__( scale: torch.Tensor, shape: torch.Size, dtype=None, + transposed: bool = False, *args, **kwargs ): self.int_data = int_data self.scale = scale + self.transposed = transposed def dequantize(self, output_dtype=torch.float16): return torchao.ops.fp6_weight_dequant(self.int_data.cpu(), self.scale.cpu()).to(output_dtype).to(self.int_data.device) @@ -942,10 +945,31 @@ def from_float(cls, input_float): dtype=torch.float16, ) + def _change_shape(self, shape): + return self.__class__( + self.int_data, self.scale, shape, dtype=self.dtype, transposed=self.transposed + ) + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): print(func) - if func is aten.linear.default: - return torchao.ops.fp16act_fp6weight_linear(args[1], args[0].int_data, args[0].scale) + if func is aten.mm.default: + fp16_act = args[0] + fp6_weight = args[1] + + if not fp6_weight.transposed: + raise NotImplementedError("FP8 weight must be transposed in matmul") + + return torchao.ops.fp16act_fp6weight_linear(fp16_act, fp6_weight.int_data, fp6_weight.scale) + + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + if func is aten.t.default: + args[0].transposed = not args[0].transposed + new = args[0]._change_shape(args[0].shape[::-1]) + return return_and_correct_aliasing(func, args, kwargs, new) raise NotImplementedError From da1421b8f530d012de2bc18e249677a5ead7ed2d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 9 May 2024 01:54:27 +0000 Subject: [PATCH 
20/39] remove print --- torchao/quantization/subclass.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 1e5e182c22..a7f31db550 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -952,7 +952,6 @@ def _change_shape(self, shape): @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): - print(func) if func is aten.mm.default: fp16_act = args[0] fp6_weight = args[1] From a0a53a0d37ab5e58d1c3228a9a3e6f9083ef20e5 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 9 May 2024 06:32:37 +0000 Subject: [PATCH 21/39] remove dequantize --- torchao/quantization/subclass.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index a7f31db550..c5b88478d4 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -907,8 +907,9 @@ def __init__( self.scale = scale self.transposed = transposed - def dequantize(self, output_dtype=torch.float16): - return torchao.ops.fp6_weight_dequant(self.int_data.cpu(), self.scale.cpu()).to(output_dtype).to(self.int_data.device) + def dequantize(self, output_dtype=torch.float32): + raise NotImplementedError + # we don't have a kernel to revert torchao.ops.prepack_fp6_weight() # https://github.com/microsoft/DeepSpeed/blob/0b224edcf7d83713b95ad6b989694a8bdf01809e/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py @classmethod From 079e16bf0e867e685a38862881f3dd9fdb5fcd7d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 9 May 2024 06:49:24 +0000 Subject: [PATCH 22/39] add notes and some rename --- test/test_ops.py | 12 +++---- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 16 ++++----- torchao/csrc/fp6_llm/fp6_llm.cpp | 2 +- torchao/ops.py | 40 ++++++++++++++++++----- torchao/quantization/subclass.py | 2 +- 5 files changed, 48 insertions(+), 24 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index a3a8a6b126..24d03cd17b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -57,7 +57,7 @@ def test_prepack_fp6_weight(self): opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - def test_fake_fp6_to_fp6(self): + def test_fp16_to_fp6(self): OC = 256 IC = 256 @@ -65,16 +65,16 @@ def test_fake_fp6_to_fp6(self): # also, we don't have nan/inf fp6_absmax = 28.0 # 2 ** (0b111 - 0b011) * (1 + 0.5 + 0.25), where E=111, M=11 fp6_absmin = 0.0625 # 2 ** (-0b010) * 0.25, where E=000, M=01 (subnormal number) - fake_fp6_weight = torch.randn((OC, IC), dtype=torch.float16) - fake_fp6_weight.clip_(-fp6_absmax, fp6_absmax) - fake_fp6_weight[fake_fp6_weight.abs() < fp6_absmin] = 0 + fp16_weight = torch.randn((OC, IC), dtype=torch.float16) + fp16_weight.clip_(-fp6_absmax, fp6_absmax) + fp16_weight[fp16_weight.abs() < fp6_absmin] = 0 # smoke test - torchao.ops.fake_fp6_to_fp6(fake_fp6_weight) + torchao.ops.fp16_to_fp6(fp16_weight) # comprehensive testing test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.fake_fp6_to_fp6, (fake_fp6_weight,), test_utils=test_utils) + opcheck(torch.ops.torchao.fp16_to_fp6, (fp16_weight,), test_utils=test_utils) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_fp16act_fp6weight_linear(self): diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu 
b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index d6ef19e19e..2f274c17a2 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -148,14 +148,14 @@ void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t namespace torchao { // https://github.com/microsoft/DeepSpeed/blob/0fc19b6a320cf8aa0a5f6c2b1fa310bae9a70d94/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels.cpp#L194 -at::Tensor fake_fp6_to_fp6_cpu(at::Tensor fake_fp6_tensor) +at::Tensor fp16_to_fp6_cpu(at::Tensor fp16_tensor) { - TORCH_CHECK(fake_fp6_tensor.dim() == 2, "weight must be 2-dimensional"); - TORCH_CHECK(fake_fp6_tensor.scalar_type() == torch::kFloat16, "weight must be FP16"); - TORCH_CHECK(fake_fp6_tensor.is_contiguous(), "weight must be contiguous"); - TORCH_CHECK(fake_fp6_tensor.device().type() == torch::kCPU, "weight must be on CPU"); - auto M = fake_fp6_tensor.size(0); - auto K = fake_fp6_tensor.size(1); + TORCH_CHECK(fp16_tensor.dim() == 2, "weight must be 2-dimensional"); + TORCH_CHECK(fp16_tensor.scalar_type() == torch::kFloat16, "weight must be FP16"); + TORCH_CHECK(fp16_tensor.is_contiguous(), "weight must be contiguous"); + TORCH_CHECK(fp16_tensor.device().type() == torch::kCPU, "weight must be on CPU"); + auto M = fp16_tensor.size(0); + auto K = fp16_tensor.size(1); TORCH_CHECK(K % 4 == 0, "K must be multiple of 4"); // Pack weight from FP16 to FP6. @@ -198,7 +198,7 @@ at::Tensor weight_matrix_dequant_cpu(at::Tensor fp6_tensor, at::Tensor fp16_scal } TORCH_LIBRARY_IMPL(torchao, CPU, m) { - m.impl("torchao::fake_fp6_to_fp6", &fake_fp6_to_fp6_cpu); + m.impl("torchao::fp16_to_fp6", &fp16_to_fp6_cpu); m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu); } diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index 9e1bcaaeb2..794c79df11 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -6,6 +6,6 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.impl_abstract_pystub("torchao.ops"); m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); - m.def("fake_fp6_to_fp6(Tensor fake_fp6_tensor) -> Tensor"); + m.def("fp16_to_fp6(Tensor fp16_tensor) -> Tensor"); m.def("fp6_weight_dequant(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); } diff --git a/torchao/ops.py b/torchao/ops.py index 43a0b2c7ba..3a25dbf6db 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -24,6 +24,15 @@ def _(dets, scores, iou_threshold): def prepack_fp6_weight(fp6_weight: Tensor) -> Tensor: + """ + Pack FP6 tensor in a layout for use with FP6-LLM. See https://arxiv.org/abs/2401.14112 for more details. + + Arguments + fp6_weight: tightly-packed fp6_weight, inside a `torch.int32` container + + Returns + packed FP6 tensor for use with FP6-LLM, inside a `torch.int32` container + """ return torch.ops.torchao.prepack_fp6_weight.default(fp6_weight) @@ -33,20 +42,35 @@ def _(fp6_weight): return torch.empty_like(fp6_weight) -def fake_fp6_to_fp6(fake_fp6_tensor: Tensor) -> Tensor: - return torch.ops.torchao.fake_fp6_to_fp6.default(fake_fp6_tensor) +def fp16_to_fp6(fp16_tensor: Tensor) -> Tensor: + """ + Pack FP16 tensor (containing only FP6 values) into FP6 tensor. 
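+
+    Arguments
+      fp16_tensor: (M, K) tensor in FP16 whose values are assumed to be exactly representable in FP6 (E3M2); K must be a multiple of 4
+
+    Returns
+      (M, K * 6 // 8) tensor of packed FP6 values, inside a `torch.uint8` container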
+ """ + return torch.ops.torchao.fp16_to_fp6.default(fp16_tensor) -@torch.library.impl_abstract("torchao::fake_fp6_to_fp6") -def _(fake_fp6_tensor): - torch._check(fake_fp6_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fake_fp6_tensor.dim()}D") - torch._check(fake_fp6_tensor.dtype is torch.float16, lambda: f"weight must be FP16, got {fake_fp6_tensor.dtype}") - M, K = fake_fp6_tensor.shape +@torch.library.impl_abstract("torchao::fp16_to_fp6") +def _(fp16_tensor): + torch._check(fp16_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp16_tensor.dim()}D") + torch._check(fp16_tensor.dtype is torch.float16, lambda: f"weight must be FP16, got {fp16_tensor.dtype}") + M, K = fp16_tensor.shape torch._check(K % 4 == 0, lambda: f"second dimension must be a multiple of 4, got {K}") - return torch.empty((M, K * 6 // 8), dtype=torch.uint8, device=fake_fp6_tensor.device) + return torch.empty((M, K * 6 // 8), dtype=torch.uint8, device=fp16_tensor.device) def fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: int = 1) -> Tensor: + """ + FP6-LLM linear layer A @ W.T. See https://arxiv.org/abs/2401.14112 for more details. + + Arguments + _in_feats: input activations in FP16 + _weights: packed FP6 weights. See :func:prepack_fp6_weight and :func:fp16_to_fp6 + _scales: scale + splitK: split K + + Returns + output of linear layer + """ return torch.ops.torchao.fp16act_fp6weight_linear.default(_in_feats, _weights, _scales, splitK) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index c5b88478d4..48f9171e69 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -936,7 +936,7 @@ def from_float(cls, input_float): scales = scales.half() quantized_fake_fp6 = float_quantize(scaled_input, exp_bits, man_bits, rounding="nearest").half() - fp6_weight = torchao.ops.fake_fp6_to_fp6(quantized_fake_fp6) + fp6_weight = torchao.ops.fp16_to_fp6(quantized_fake_fp6) fp6_weight = torchao.ops.prepack_fp6_weight(fp6_weight.view(torch.int32)) return cls( From 06e84384c0e6469141698a453c23bef01f2a8bd9 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 9 May 2024 06:51:03 +0000 Subject: [PATCH 23/39] typo --- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index 2f274c17a2..a88fc40d31 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -163,7 +163,7 @@ at::Tensor fp16_to_fp6_cpu(at::Tensor fp16_tensor) auto packed_fp6_tensor = at::empty({M, K * 6 / 8}, options); uint8_t* packed_fp6_ptr = packed_fp6_tensor.data_ptr(); - uint16_t* fake_fp6_ptr = reinterpret_cast(fake_fp6_tensor.data_ptr()); + uint16_t* fake_fp6_ptr = reinterpret_cast(fp16_tensor.data_ptr()); weight_prepacking_fp16_to_fp6(fake_fp6_ptr, packed_fp6_ptr, M, K); return packed_fp6_tensor; From ca452748b610deac416ad07062cd9577e5256a18 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 9 May 2024 06:59:41 +0000 Subject: [PATCH 24/39] small cleanup --- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index a88fc40d31..1591f82b84 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -181,18 +181,18 @@ at::Tensor fp16_to_fp6_cpu(at::Tensor 
fp16_tensor) at::Tensor weight_matrix_dequant_cpu(at::Tensor fp6_tensor, at::Tensor fp16_scale) { int OC = fp6_tensor.size(0); - assert(fp6_tensor.size(1) % 3 == 0); + TORCH_CHECK(fp6_tensor.size(1) % 3 == 0); int IC = fp6_tensor.size(1) / 3 * 16; - assert(fp16_scale.size(0)==OC); + TORCH_CHECK(fp16_scale.size(0) == OC); // - auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); + auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); auto fp16_scale_ptr = reinterpret_cast(fp16_scale.data_ptr()); // - auto options = at::TensorOptions().dtype(fp16_scale.dtype()).device(fp16_scale.device()); + auto options = at::TensorOptions().dtype(at::kHalf).device(fp16_scale.device()); at::Tensor fp16_tensor = at::empty({OC, IC}, options); auto fp16_tensor_ptr = reinterpret_cast(fp16_tensor.data_ptr()); // - DeQuantMatrix_FP6_To_FP16(fp16_tensor_ptr, (unsigned char*)fp6_tensor_ptr, OC, IC, fp16_scale_ptr); + DeQuantMatrix_FP6_To_FP16(fp16_tensor_ptr, fp6_tensor_ptr, OC, IC, fp16_scale_ptr); // return fp16_tensor; } From 7a0f6e2aee8d8ccf1c20f6a164cb4ed469af0709 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 9 May 2024 21:25:26 +0800 Subject: [PATCH 25/39] improve tensor subclass and add test (which is failing for torch-compile) --- test/integration/test_integration.py | 12 +++++++++- torchao/quantization/subclass.py | 33 +++++++++++++++++++++++----- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index e6da3e7340..f8bc634a24 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -48,7 +48,8 @@ from torchao.quantization.subclass import ( Int8DynamicallyQuantizedLinearWeight, Int8WeightOnlyQuantizedLinearWeight, - Int4WeightOnlyQuantizedLinearWeight + Int4WeightOnlyQuantizedLinearWeight, + Fp6WeightOnlyQuantizedLinearWeight, ) from torchao.quantization.utils import ( _apply_logging_hook, @@ -1074,6 +1075,15 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): test_dtype=dtype, ) + @parameterized.expand(COMMON_DEVICE_DTYPE) + def test_fp6_weight_only_quant_subclass(self, device, dtype): + if device != "cuda" or dtype != torch.float16: + self.skipTest(f"Not implemented for {device} {dtype}") + for test_shape in [(16, 1024, 256), (1, 1024, 256)]: + self._test_lin_weight_subclass_impl( + Fp6WeightOnlyQuantizedLinearWeight.from_float, device, 10, test_shape=test_shape, test_dtype=dtype + ) + class TestDynamicQuant(unittest.TestCase): def test_dynamic_quant(self): diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 48f9171e69..fbe533148f 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -951,16 +951,39 @@ def _change_shape(self, shape): self.int_data, self.scale, shape, dtype=self.dtype, transposed=self.transposed ) + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.int_data), + fn(self.scale), + self.shape, + dtype=self.dtype, + transposed=self.transposed, + ) + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): - if func is aten.mm.default: - fp16_act = args[0] - fp6_weight = args[1] + if func in (aten.mm.default, aten.addmm.default): + if func is aten.mm.default: + fp16_act = args[0] + fp6_weight = args[1] + bias = False + else: + fp16_act = args[1] + fp6_weight = args[2] + bias = args[0] if not fp6_weight.transposed: raise NotImplementedError("FP8 weight must be transposed in matmul") - return 
torchao.ops.fp16act_fp6weight_linear(fp16_act, fp6_weight.int_data, fp6_weight.scale) + out = torchao.ops.fp16act_fp6weight_linear(fp16_act, fp6_weight.int_data, fp6_weight.scale) + if bias is not None: # we don't have fused bias kernel + out = out + bias + return out + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) if func is aten.clone.default: return return_and_correct_aliasing( @@ -972,4 +995,4 @@ def __torch_dispatch__(cls, func, types, args, kwargs): new = args[0]._change_shape(args[0].shape[::-1]) return return_and_correct_aliasing(func, args, kwargs, new) - raise NotImplementedError + raise NotImplementedError(f"{func} is not implemented for {cls.__name__}") From 320827ebf16f308a88a563a97f5c321d257fbe58 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 9 May 2024 21:26:36 +0800 Subject: [PATCH 26/39] add note --- torchao/quantization/subclass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index fbe533148f..a40878a1a0 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -925,7 +925,7 @@ def from_float(cls, input_float): num_bits = 6 exp_bits = 3 - q_range = 28 + q_range = 28 # max value for E3M2 man_bits = num_bits - exp_bits - 1 From 74b80940ddf812608c8d93f6e12a5c4a6b3c636a Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 9 May 2024 21:27:40 +0800 Subject: [PATCH 27/39] add note --- torchao/quantization/subclass.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index a40878a1a0..74eb7fd4aa 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -870,6 +870,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs): class Fp6WeightOnlyQuantizedLinearWeight(torch.Tensor): + """ + FP6 weight for use with FP6-LLM. See https://arxiv.org/abs/2401.14112 for more details. 
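+
+    The weight is stored as the packed layout produced by torchao.ops.fp16_to_fp6() followed by
+    torchao.ops.prepack_fp6_weight(), together with one FP16 scale per output channel; matrix
+    multiplications are dispatched to torchao.ops.fp16act_fp6weight_linear().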
+ """ @staticmethod def __new__( From c8d47c3d5e9031c5a7cf0021c1520b347a6533f9 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 9 May 2024 21:29:08 +0800 Subject: [PATCH 28/39] add qtorch as dev requirement --- dev-requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dev-requirements.txt b/dev-requirements.txt index 6dadb274aa..4efa3836c3 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -12,3 +12,5 @@ pandas # Custom CUDA Extensions ninja + +qtorch # for FP6-LLM From e08ba6ae1649308958e17cb214c985483313a7f2 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 9 May 2024 22:15:15 +0800 Subject: [PATCH 29/39] update error message --- torchao/csrc/fp6_llm/weight_prepacking.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/csrc/fp6_llm/weight_prepacking.cpp b/torchao/csrc/fp6_llm/weight_prepacking.cpp index 2b03eb32ab..39616f81ff 100644 --- a/torchao/csrc/fp6_llm/weight_prepacking.cpp +++ b/torchao/csrc/fp6_llm/weight_prepacking.cpp @@ -187,9 +187,9 @@ at::Tensor weight_matrix_prepacking_cpu(at::Tensor fp6_tensor) { size_t OC = fp6_tensor.size(0); size_t IC = fp6_tensor.size(1); - TORCH_CHECK(IC % 3 == 0); + TORCH_CHECK(IC % 3 == 0, "Expect packed input dim % 3 == 0, but receive ", IC, " instead."); IC = IC * 16 / 3; - TORCH_CHECK((OC % 256 == 0) && (IC % 64 == 0)); + TORCH_CHECK((OC % 256 == 0) && (IC % 64 == 0), "Expect output dim % 256 == 0 and input dim % 64 == 0, but receive ", OC, " and ", IC, " instead."); auto packed_tensor = at::empty_like(fp6_tensor); auto packed_tensor_ptr = reinterpret_cast(packed_tensor.data_ptr()); auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); From b090b4b04645799b6be896b1cb7eb61e6259a660 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 9 May 2024 23:18:35 +0800 Subject: [PATCH 30/39] add __repr__ and fix transposed issue --- torchao/quantization/subclass.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 74eb7fd4aa..5309fc305f 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -963,6 +963,12 @@ def _apply_fn_to_data(self, fn): transposed=self.transposed, ) + def __repr__(self): + return ( + f"{self.__class__.__name__}(shape={self.shape}, transposed={self.transposed}, " + f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" + ) + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): if func in (aten.mm.default, aten.addmm.default): @@ -994,8 +1000,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) if func is aten.t.default: - args[0].transposed = not args[0].transposed new = args[0]._change_shape(args[0].shape[::-1]) + new.transposed = not args[0].transposed return return_and_correct_aliasing(func, args, kwargs, new) raise NotImplementedError(f"{func} is not implemented for {cls.__name__}") From f6f93c39cae3bea2e6d4294633cedfb2fccc9ae6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 9 May 2024 23:34:47 +0800 Subject: [PATCH 31/39] add fp6 perplexity test --- test_fp6_perplexity.py | 55 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 test_fp6_perplexity.py diff --git a/test_fp6_perplexity.py b/test_fp6_perplexity.py new file mode 100644 index 0000000000..0e9d5febd1 --- /dev/null +++ b/test_fp6_perplexity.py @@ -0,0 +1,55 @@ +# adapted from https://huggingface.co/docs/transformers/perplexity + +from transformers import AutoModelForCausalLM, 
AutoTokenizer +from datasets import load_dataset +import torch +from tqdm import tqdm +from torchao.quantization.subclass import Fp6WeightOnlyQuantizedLinearWeight + +dtype = "fp32" # fp32, fp16, or fp6 + +device = "cuda" +model_id = "microsoft/Phi-3-mini-4k-instruct" +tokenizer_id = "microsoft/Phi-3-mini-4k-instruct" +model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, trust_remote_code=True) +if dtype != "fp32": + model.half() +tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + +if dtype == "fp6": + modules = list(model.named_modules()) + for name, module in tqdm(modules, desc="Converting weight to FP6"): + if isinstance(module, torch.nn.Linear): + try: + fp6_weight = Fp6WeightOnlyQuantizedLinearWeight.from_float(module.weight.detach()) + module.weight = torch.nn.Parameter(fp6_weight, requires_grad=False) + except Exception as e: + print(f"Unable to convert {name}.weight to FP6. {e}") # typically LM head + +test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") +encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt") + +max_length = model.config.max_length +stride = 512 +seq_len = encodings.input_ids.size(1) + +nlls = [] +prev_end_loc = 0 +for begin_loc in tqdm(range(0, seq_len, stride)): + end_loc = min(begin_loc + max_length, seq_len) + trg_len = end_loc - prev_end_loc # may be different from stride on last loop + input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device) + target_ids = input_ids.clone() + target_ids[:, :-trg_len] = -100 + + with torch.no_grad(): + nll = model(input_ids, labels=target_ids).loss + + nlls.append(nll) + + prev_end_loc = end_loc + if end_loc == seq_len: + break + +ppl = torch.exp(torch.stack(nlls).mean()) +print(ppl) From b857645ac3b720038247c1d887863e6d13b5026e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 10 May 2024 01:06:55 +0000 Subject: [PATCH 32/39] rename variables --- torchao/quantization/subclass.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 5309fc305f..1dc8c70c92 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -877,7 +877,7 @@ class Fp6WeightOnlyQuantizedLinearWeight(torch.Tensor): @staticmethod def __new__( cls, - int_data: torch.Tensor, + fp6_data: torch.Tensor, scale: torch.Tensor, shape: torch.Size, dtype=None, @@ -885,10 +885,7 @@ def __new__( *args, **kwargs ): - kwargs["device"] = int_data.device - # kwargs["layout"] = ( - # kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout - # ) + kwargs["device"] = fp6_data.device if dtype is None: dtype = scale.dtype kwargs["dtype"] = dtype @@ -898,7 +895,7 @@ def __new__( def __init__( self, - int_data: torch.Tensor, + fp6_data: torch.Tensor, scale: torch.Tensor, shape: torch.Size, dtype=None, @@ -906,7 +903,7 @@ def __init__( *args, **kwargs ): - self.int_data = int_data + self.fp6_data = fp6_data self.scale = scale self.transposed = transposed @@ -951,12 +948,12 @@ def from_float(cls, input_float): def _change_shape(self, shape): return self.__class__( - self.int_data, self.scale, shape, dtype=self.dtype, transposed=self.transposed + self.fp6_data, self.scale, shape, dtype=self.dtype, transposed=self.transposed ) def _apply_fn_to_data(self, fn): return self.__class__( - fn(self.int_data), + fn(self.fp6_data), fn(self.scale), self.shape, dtype=self.dtype, @@ -984,7 +981,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): if not 
fp6_weight.transposed: raise NotImplementedError("FP8 weight must be transposed in matmul") - out = torchao.ops.fp16act_fp6weight_linear(fp16_act, fp6_weight.int_data, fp6_weight.scale) + out = torchao.ops.fp16act_fp6weight_linear(fp16_act, fp6_weight.fp6_data, fp6_weight.scale) if bias is not None: # we don't have fused bias kernel out = out + bias return out From f0eba1a2d8ec88ea215bac32d71d912cf43370e5 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 10 May 2024 19:49:59 +0800 Subject: [PATCH 33/39] remove subclass --- test/integration/test_integration.py | 10 -- torchao/quantization/subclass.py | 135 --------------------------- 2 files changed, 145 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index f8bc634a24..05096385fc 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -49,7 +49,6 @@ Int8DynamicallyQuantizedLinearWeight, Int8WeightOnlyQuantizedLinearWeight, Int4WeightOnlyQuantizedLinearWeight, - Fp6WeightOnlyQuantizedLinearWeight, ) from torchao.quantization.utils import ( _apply_logging_hook, @@ -1075,15 +1074,6 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): test_dtype=dtype, ) - @parameterized.expand(COMMON_DEVICE_DTYPE) - def test_fp6_weight_only_quant_subclass(self, device, dtype): - if device != "cuda" or dtype != torch.float16: - self.skipTest(f"Not implemented for {device} {dtype}") - for test_shape in [(16, 1024, 256), (1, 1024, 256)]: - self._test_lin_weight_subclass_impl( - Fp6WeightOnlyQuantizedLinearWeight.from_float, device, 10, test_shape=test_shape, test_dtype=dtype - ) - class TestDynamicQuant(unittest.TestCase): def test_dynamic_quant(self): diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 1dc8c70c92..3386b98de5 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -867,138 +867,3 @@ def __torch_dispatch__(cls, func, types, args, kwargs): kwargs, args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), ) - - -class Fp6WeightOnlyQuantizedLinearWeight(torch.Tensor): - """ - FP6 weight for use with FP6-LLM. See https://arxiv.org/abs/2401.14112 for more details. - """ - - @staticmethod - def __new__( - cls, - fp6_data: torch.Tensor, - scale: torch.Tensor, - shape: torch.Size, - dtype=None, - transposed: bool = False, - *args, - **kwargs - ): - kwargs["device"] = fp6_data.device - if dtype is None: - dtype = scale.dtype - kwargs["dtype"] = dtype - assert not kwargs.get("requires_grad", False) - kwargs["requires_grad"] = False - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - fp6_data: torch.Tensor, - scale: torch.Tensor, - shape: torch.Size, - dtype=None, - transposed: bool = False, - *args, - **kwargs - ): - self.fp6_data = fp6_data - self.scale = scale - self.transposed = transposed - - def dequantize(self, output_dtype=torch.float32): - raise NotImplementedError - # we don't have a kernel to revert torchao.ops.prepack_fp6_weight() - - # https://github.com/microsoft/DeepSpeed/blob/0b224edcf7d83713b95ad6b989694a8bdf01809e/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py - @classmethod - def from_float(cls, input_float): - try: - from qtorch.quant import float_quantize - except: - logging.exception("qtorch is not available. 
Please install qtorch to use FP6 weight") - raise - - device = input_float.device - input_float = input_float.cpu().float() - - num_bits = 6 - exp_bits = 3 - q_range = 28 # max value for E3M2 - - man_bits = num_bits - exp_bits - 1 - - max_input = input_float.abs().amax(dim=1) # symmetric quantization - scales = max_input / q_range # q_range + 1 - scales[scales == 0] = 1 # avoid zero scales - scaled_input = input_float / scales.view(-1, 1) - scales = scales.half() - - quantized_fake_fp6 = float_quantize(scaled_input, exp_bits, man_bits, rounding="nearest").half() - fp6_weight = torchao.ops.fp16_to_fp6(quantized_fake_fp6) - fp6_weight = torchao.ops.prepack_fp6_weight(fp6_weight.view(torch.int32)) - - return cls( - fp6_weight.to(device), - scales.to(device), - input_float.shape, - dtype=torch.float16, - ) - - def _change_shape(self, shape): - return self.__class__( - self.fp6_data, self.scale, shape, dtype=self.dtype, transposed=self.transposed - ) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.fp6_data), - fn(self.scale), - self.shape, - dtype=self.dtype, - transposed=self.transposed, - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(shape={self.shape}, transposed={self.transposed}, " - f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - if func in (aten.mm.default, aten.addmm.default): - if func is aten.mm.default: - fp16_act = args[0] - fp6_weight = args[1] - bias = False - else: - fp16_act = args[1] - fp6_weight = args[2] - bias = args[0] - - if not fp6_weight.transposed: - raise NotImplementedError("FP8 weight must be transposed in matmul") - - out = torchao.ops.fp16act_fp6weight_linear(fp16_act, fp6_weight.fp6_data, fp6_weight.scale) - if bias is not None: # we don't have fused bias kernel - out = out + bias - return out - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - if func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - if func is aten.t.default: - new = args[0]._change_shape(args[0].shape[::-1]) - new.transposed = not args[0].transposed - return return_and_correct_aliasing(func, args, kwargs, new) - - raise NotImplementedError(f"{func} is not implemented for {cls.__name__}") From 8f1ef8dfa424a2f12929d7a956edd7041fa37765 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 10 May 2024 20:08:32 +0800 Subject: [PATCH 34/39] add correctness test --- fp6_test.py | 98 ------------------------------------------------ fp6_test_run.sh | 47 ----------------------- test/test_ops.py | 43 +++++++++++++++------ 3 files changed, 31 insertions(+), 157 deletions(-) delete mode 100644 fp6_test.py delete mode 100644 fp6_test_run.sh diff --git a/fp6_test.py b/fp6_test.py deleted file mode 100644 index a83769e7c2..0000000000 --- a/fp6_test.py +++ /dev/null @@ -1,98 +0,0 @@ -# from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py - -import argparse -import torch -import torchao - -WARMUP = 10 -REPEAT = 1000 - -parser = argparse.ArgumentParser(description='The shape of the MatMul: (M, K)*(K, N)->(M, N).') -parser.add_argument('--OC', type=int, required=False, default=4096, help='number of rows of the weight matrix.') -parser.add_argument('--IC', type=int, required=False, default=4096, help='number of columns of the weight matrix.') 
-parser.add_argument('--BS', type=int, required=False, default=32, help='inference batch size.') -parser.add_argument('--splitK', type=int, required=False, default=1, help='Split-K parameters allow users to split the GEMM computation along the K dimension so that more CTAs will be created with a better SM utilization.') -args = parser.parse_args() - -assert(args.OC%256==0) -assert(args.IC%64==0) - -print("#"*64) -print(args) - -fp6_weight = torch.randint(4294967295, (args.OC,args.IC//16*3)).to(torch.int) # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. -fp16_scale = torch.rand(args.OC).to(torch.half)+0.5 -fp16_activation = torch.rand(args.BS, args.IC).to(torch.half)+0.5 - -start_event = torch.cuda.Event(enable_timing=True) -end_event = torch.cuda.Event(enable_timing=True) - -# fp6-fp16 GEMM (fp6-llm) -#################################################################################################################################### -torch.cuda.synchronize() -fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight) -act_cuda = fp16_activation.cuda() -weight_cuda = fp6_weight_packed.cuda() -scale_cuda = fp16_scale.cuda() -for i in range(WARMUP): - results_fp6_llm = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, args.splitK); -start_event.record() -for i in range(REPEAT): - results_fp6_llm = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, args.splitK); -end_event.record() -torch.cuda.synchronize() -fp6_llm_time_ms = start_event.elapsed_time(end_event)/REPEAT -fp6_llm_tflops = args.OC*args.IC*args.BS*2/fp6_llm_time_ms/1e9 -#################################################################################################################################### - -# baseline fp16 GEMM (cuBLAS) -#################################################################################################################################### -torch.cuda.synchronize() -fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale) -cuBLAS_MatMul = torch.nn.Linear(args.IC, args.OC, False) -results_cublas = None -with torch.no_grad(): - cuBLAS_MatMul.weight = torch.nn.Parameter(fp16_weight.clone().cuda()) - act_cuda = fp16_activation.cuda() - for i in range(WARMUP): - results_cublas = cuBLAS_MatMul(act_cuda) - start_event.record() - for i in range(REPEAT): - results_cublas = cuBLAS_MatMul(act_cuda) - end_event.record() -torch.cuda.synchronize() -cublas_time_ms = start_event.elapsed_time(end_event)/REPEAT -cublas_tflops = args.OC*args.IC*args.BS*2/cublas_time_ms/1e9 -#################################################################################################################################### - -# Performance -print( 'cuBLAS time: {:.2f} ms \t\t cuBLAS TFLOPs: {:.1f}'.format(cublas_time_ms, cublas_tflops) ) -print( 'fp6-llm time: {:.2f} ms \t\t fp6-llm TFLOPs: {:.1f}'.format(fp6_llm_time_ms, fp6_llm_tflops) ) -print( 'speedup: {:.2f}'.format(cublas_time_ms/fp6_llm_time_ms) ) - -# Correctness -error = results_cublas.cpu() - results_fp6_llm.cpu() -ground_truth = results_cublas.cpu() -mean_error = torch.mean(abs(error)) -mean_ground_truth = torch.mean(abs(ground_truth)) -relative_error = mean_error.item()/mean_ground_truth.item() -print( "relative error: {:.6f}".format(relative_error) ) - - - -# 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 -# "00110000110000110000110000110000" "11000011000011000011000011000011" 
"00001100001100001100001100001100" -# 818089008 3272356035 204522252 -#fp6_weight = torch.zeros(args.OC, args.IC//16*3).to(torch.int64) -#for i in range(args.OC): -# for j in range(args.IC//16): -# fp6_weight[i][j*3+0] = 818089008 -# fp6_weight[i][j*3+1] = 3272356035 -# fp6_weight[i][j*3+2] = 204522252 -#fp6_weight = fp6_weight.to(torch.int) - -# Ensuring that the absolute error or relative error of each matrix element is smaller than 1e-3. -#Error = [1e-2] -#for err in Error: -# AllClose = torch.allclose(results_fp6_llm.cpu(), results_cublas.cpu(), rtol=err, atol=err, equal_nan=True) -# print("torch.allclose\t (relative/absolute_error<" + str(err) + ") \t-> " + str(AllClose)) diff --git a/fp6_test_run.sh b/fp6_test_run.sh deleted file mode 100644 index a2462d2305..0000000000 --- a/fp6_test_run.sh +++ /dev/null @@ -1,47 +0,0 @@ -#! /bin/bash - -# [Batch sizes to test] -# If you want to test the performance of FP6-LLM for larger inference batch sizes, -# which typically happens during prompt processing, -# please revise this file by simply "commenting" and "uncommenting". - -N=(1 2 4 8 16 32 64) -SplitK=(5 6 7 6) - -# BS <=64 -#N=(1 2 4 8 16 32 64) -#SplitK=(5 6 7 6) - -# BS = 128 -#N=(128) -#SplitK=(5 3 3 3) - -# BS = 256 -#N=(256) -#SplitK=(4 3 2 3) - -# BS = 512 -#N=(512) -#SplitK=(2 5 2 4) - -# BS = 1024 -#N=(1024) -#SplitK=(1 2 1 2) - -# BS >= 2048 -# N = (2048, 4096, 8192, 16384) -#SplitK=(1 1 1 1) - -# Benchmarking the specific Matrix Shape from llama2-70b -M=(10240 8192 57344 8192) -K=(8192 8192 8192 28672) - -#mkdir -p Profiling -for ((i=0;i<${#M[@]};i++)) -do - for BS in ${N[@]} - do - #ncu -f -o Profiling/M${M[i]}K${K[i]}N${BS} --set full \ - python fp6_test.py --OC=${M[i]} --IC=${K[i]} --BS=${BS} --splitK=${SplitK[i]} - done -done diff --git a/test/test_ops.py b/test/test_ops.py index 24d03cd17b..6d964396d7 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -4,6 +4,7 @@ import torchao from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 import unittest +from parameterized import parameterized # torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...): @@ -42,12 +43,17 @@ def test_nms(self): test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] opcheck(torch.ops.torchao.nms, (boxes, scores, iou), test_utils=test_utils) + def _create_fp6_inputs(self, BS: int, OC: int, IC: int): + # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. + fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int) + fp16_scale = torch.rand(OC).half() + 0.5 + fp16_activation = torch.rand(BS, IC).half() + 0.5 + return fp6_weight, fp16_scale, fp16_activation + def test_prepack_fp6_weight(self): OC = 256 IC = 256 - - # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. - fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int) + fp6_weight, _, _ = self._create_fp6_inputs(0, OC, IC) # smoke test torchao.ops.prepack_fp6_weight(fp6_weight) @@ -82,11 +88,7 @@ def test_fp16act_fp6weight_linear(self): OC = 256 IC = 256 splitK = 1 - - # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. 
- fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int) - fp16_scale = torch.rand(OC).to(torch.float16) + 0.5 - fp16_activation = torch.rand(BS, IC).to(torch.float16) + 0.5 + fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC) fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight) act_cuda = fp16_activation.cuda() @@ -104,10 +106,7 @@ def test_fp16act_fp6weight_linear(self): def test_fp6_weight_dequant(self): OC = 256 IC = 256 - - # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. - fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int) - fp16_scale = torch.rand(OC).to(torch.float16) + 0.5 + fp6_weight, fp16_scale, _ = self._create_fp6_inputs(0, OC, IC) # smoke test torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale) @@ -116,6 +115,26 @@ def test_fp6_weight_dequant(self): test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] opcheck(torch.ops.torchao.fp6_weight_dequant, (fp6_weight, fp16_scale), test_utils=test_utils) + # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py + @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_fp6_matmul_correctness(self, BS, OC, IC, splitK): + fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC) + + fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight) + act_cuda = fp16_activation.cuda() + weight_cuda = fp6_weight_packed.cuda() + scale_cuda = fp16_scale.cuda() + + results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK) + + fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda() + results_fp16 = act_cuda @ fp16_weight.T + + error = (results_fp6 - results_fp16).abs() + relative_error = error / results_fp16.abs() + assert relative_error.mean() < 1e-3 + if __name__ == "__main__": unittest.main() From cb05d305baa48d39a15570d1a19884514c699691 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 10 May 2024 20:11:44 +0800 Subject: [PATCH 35/39] remove unwanted changes --- dev-requirements.txt | 2 - test/integration/test_integration.py | 2 +- test_fp6_perplexity.py | 55 ---------------------------- torchao/__init__.py | 13 +++---- torchao/quantization/subclass.py | 2 - 5 files changed, 7 insertions(+), 67 deletions(-) delete mode 100644 test_fp6_perplexity.py diff --git a/dev-requirements.txt b/dev-requirements.txt index 4efa3836c3..6dadb274aa 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -12,5 +12,3 @@ pandas # Custom CUDA Extensions ninja - -qtorch # for FP6-LLM diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 05096385fc..e6da3e7340 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -48,7 +48,7 @@ from torchao.quantization.subclass import ( Int8DynamicallyQuantizedLinearWeight, Int8WeightOnlyQuantizedLinearWeight, - Int4WeightOnlyQuantizedLinearWeight, + Int4WeightOnlyQuantizedLinearWeight ) from torchao.quantization.utils import ( _apply_logging_hook, diff --git a/test_fp6_perplexity.py b/test_fp6_perplexity.py deleted file mode 100644 index 0e9d5febd1..0000000000 --- a/test_fp6_perplexity.py +++ /dev/null @@ -1,55 +0,0 @@ -# adapted from https://huggingface.co/docs/transformers/perplexity - -from transformers import AutoModelForCausalLM, 
AutoTokenizer -from datasets import load_dataset -import torch -from tqdm import tqdm -from torchao.quantization.subclass import Fp6WeightOnlyQuantizedLinearWeight - -dtype = "fp32" # fp32, fp16, or fp6 - -device = "cuda" -model_id = "microsoft/Phi-3-mini-4k-instruct" -tokenizer_id = "microsoft/Phi-3-mini-4k-instruct" -model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, trust_remote_code=True) -if dtype != "fp32": - model.half() -tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) - -if dtype == "fp6": - modules = list(model.named_modules()) - for name, module in tqdm(modules, desc="Converting weight to FP6"): - if isinstance(module, torch.nn.Linear): - try: - fp6_weight = Fp6WeightOnlyQuantizedLinearWeight.from_float(module.weight.detach()) - module.weight = torch.nn.Parameter(fp6_weight, requires_grad=False) - except Exception as e: - print(f"Unable to convert {name}.weight to FP6. {e}") # typically LM head - -test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") -encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt") - -max_length = model.config.max_length -stride = 512 -seq_len = encodings.input_ids.size(1) - -nlls = [] -prev_end_loc = 0 -for begin_loc in tqdm(range(0, seq_len, stride)): - end_loc = min(begin_loc + max_length, seq_len) - trg_len = end_loc - prev_end_loc # may be different from stride on last loop - input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device) - target_ids = input_ids.clone() - target_ids[:, :-trg_len] = -100 - - with torch.no_grad(): - nll = model(input_ids, labels=target_ids).loss - - nlls.append(nll) - - prev_end_loc = end_loc - if end_loc == seq_len: - break - -ppl = torch.exp(torch.stack(nlls).mean()) -print(ppl) diff --git a/torchao/__init__.py b/torchao/__init__.py index c8f04c1d9e..c982e09a0c 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -1,3 +1,9 @@ +from torchao.quantization import ( + apply_weight_only_int8_quant, + apply_dynamic_quant, + autoquant, +) +from . import dtypes import torch _IS_FBCODE = ( hasattr(torch._utils_internal, "IS_FBSOURCE") and @@ -8,13 +14,6 @@ from . import _C from . import ops -from torchao.quantization import ( - apply_weight_only_int8_quant, - apply_dynamic_quant, - autoquant, -) -from . import dtypes - __all__ = [ "dtypes", "apply_dynamic_quant", diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 3386b98de5..6128720d4d 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -4,13 +4,11 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
-import logging import warnings import torch from torch.utils._python_dispatch import return_and_correct_aliasing -import torchao.ops from .quant_primitives import ( dequantize_per_channel, dynamically_quantize_per_channel, From 56aefc6b034aebb3126ca1e4b89501eabf90ea12 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 10 May 2024 20:28:53 +0800 Subject: [PATCH 36/39] add apache 2.0 notice --- torchao/csrc/cuda/fp6_llm/configs.h | 16 ++++++++++++++++ torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 16 ++++++++++++++++ torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh | 16 ++++++++++++++++ torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh | 16 ++++++++++++++++ torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh | 16 ++++++++++++++++ torchao/csrc/cuda/fp6_llm/ptx_mma.cuh | 16 ++++++++++++++++ torchao/csrc/cuda/fp6_llm/utils_core.cuh | 16 ++++++++++++++++ torchao/csrc/cuda/fp6_llm/utils_gmem.cuh | 16 ++++++++++++++++ .../cuda/fp6_llm/utils_parallel_dequant.cuh | 16 ++++++++++++++++ torchao/csrc/cuda/fp6_llm/weight_quant.cu | 18 ++++++++++++++++-- torchao/csrc/fp6_llm/weight_prepacking.cpp | 16 ++++++++++++++++ 11 files changed, 176 insertions(+), 2 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/configs.h b/torchao/csrc/cuda/fp6_llm/configs.h index e6b217cdca..0a642fc805 100644 --- a/torchao/csrc/cuda/fp6_llm/configs.h +++ b/torchao/csrc/cuda/fp6_llm/configs.h @@ -1,3 +1,19 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/configs.h + #ifndef CONFIGS_H #define CONFIGS_H diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index e9f051c2c9..51413a0874 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -1,3 +1,19 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/fp6_linear.cu + #include "kernel_matmul.cuh" #include "kernel_reduction.cuh" diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index 1a971f837f..de7775ddce 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -1,3 +1,19 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_matmul.cuh + #include "configs.h" #include "utils_gmem.cuh" #include "utils_core.cuh" diff --git a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh index 442de103b8..c0e7c1918a 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh @@ -1,3 +1,19 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_reduction.cuh + /*************************************************************************** * Copyright 2023 The FLash-LLM Authors. All rights reserved. * Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh b/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh index 8a6069ff1f..c1d064f32a 100644 --- a/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh +++ b/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh @@ -1,3 +1,19 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_cp.async.cuh + /*************************************************************************** * Copyright 2023 The FLash-LLM Authors. All rights reserved. * Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh index 1920678244..d0985bd63d 100644 --- a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh +++ b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh @@ -1,3 +1,19 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_mma.cuh + /*************************************************************************** * Copyright 2023 The FLash-LLM Authors. All rights reserved. * Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/torchao/csrc/cuda/fp6_llm/utils_core.cuh b/torchao/csrc/cuda/fp6_llm/utils_core.cuh index e0d374f22f..5bfc043ef6 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_core.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_core.cuh @@ -1,3 +1,19 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_core.cuh + #ifndef UTILS_CORE_CUH #define UTILS_CORE_CUH diff --git a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh index 86ba333b68..5c37452e13 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh @@ -1,3 +1,19 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh + #ifndef UTILS_GMEM_CUH #define UTILS_GMEM_CUH diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index 5a6977bc07..f6ce4cc046 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -1,3 +1,19 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_parallel_dequant.cuh + #ifndef UTILS_PARALLELDEQUANT_CUH #define UTILS_PARALLELDEQUANT_CUH diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index 1591f82b84..d29f70be0c 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -1,5 +1,19 @@ -// Author: Zhen Zheng -// To be used in the future as a tool to generating the FP6 matrix from the FP16 matrix. +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_quant.h +// and https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_dequant.h #include #include diff --git a/torchao/csrc/fp6_llm/weight_prepacking.cpp b/torchao/csrc/fp6_llm/weight_prepacking.cpp index 39616f81ff..89a1171f5e 100644 --- a/torchao/csrc/fp6_llm/weight_prepacking.cpp +++ b/torchao/csrc/fp6_llm/weight_prepacking.cpp @@ -1,3 +1,19 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h + #include #include #include From 7d3a5b1ac7b443a49256a964c7df62f2f0889c90 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 11 May 2024 00:05:19 +0800 Subject: [PATCH 37/39] add benchmark script --- benchmarks/benchmark_fp6.py | 82 +++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 benchmarks/benchmark_fp6.py diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py new file mode 100644 index 0000000000..abe21d2f7d --- /dev/null +++ b/benchmarks/benchmark_fp6.py @@ -0,0 +1,82 @@ +import torch +import torchao +from torch.utils.benchmark import Timer +import pandas as pd +from tqdm import tqdm + + +def benchmark(m, k, n, splitK): + # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. + fp6_weight = torch.randint(4294967295, (n, k // 16 * 3)).to(torch.int) + fp16_scale = torch.rand(n).half() + 0.5 + fp16_activation = torch.rand(m, k).half() + 0.5 + + fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight) + act_cuda = fp16_activation.cuda() + weight_cuda = fp6_weight_packed.cuda() + scale_cuda = fp16_scale.cuda() + + # need to do this since Timer cannot see torchao + def fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK): + return torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK) + + fp6_output = fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK) + + fp6_measurement = Timer( + stmt="fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK)", + globals=locals(), + ).blocked_autorange() + + fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda() + fp16_output = act_cuda @ fp16_weight.T + + fp16_measurement = Timer( + stmt="act_cuda @ fp16_weight.T", + globals=locals(), + ).blocked_autorange() + + # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py + # doesn't seem to be the right way to check for correctness + correct = (fp6_output - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3 + + return { + "m": m, + "k": k, + "n": n, + "fp6_latency (ms)": fp6_measurement.median * 1000, + "fp16_latency (ms)": fp16_measurement.median * 1000, + "speedup (d/s)": fp16_measurement.median / fp6_measurement.median, + "correct": correct, + } + + +if __name__ == "__main__": + # from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/run.sh + k_vals = (8192, 8192, 8192, 28672) + n_vals = (10240, 8192, 57344, 8192) + + results = [] + + # splitK can be tuned based on m, k, n + for m, splitK_vals in tqdm([ + (1, (5, 6, 7, 6)), + (2, (5, 6, 7, 6)), + (4, (5, 6, 7, 6)), + (8, (5, 6, 7, 6)), + # (16, (5, 6, 7, 6)), + # (64, (5, 6, 7, 6)), + # (128, (5, 3, 3, 3)), + # (256, (4, 3, 2, 3)), + # (512, (2, 5, 2, 4)), + (1024, (1, 2, 1, 2)), + (2048, (1, 1, 1, 1)), + (4096, (1, 1, 1, 1)), + # (8192, (1, 1, 1, 1)), + # (16384, (1, 1, 1, 1)), + ]): + for n, k, splitK in zip(n_vals, k_vals, splitK_vals): + results.append(benchmark(m, n, k, splitK)) + + df = pd.DataFrame(results) + df.to_csv("fp6_benchmark_results.csv", index=False) + print(df.to_markdown(index=False)) From 08a95ac5767a18c6ddc22fb9aa0f02ee4b95e52f Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 14 May 2024 01:10:24 +0000 Subject: [PATCH 38/39] add note about FP6 
 kernel

---
 torchao/csrc/fp6_llm/README.md | 7 +++++++
 1 file changed, 7 insertions(+)
 create mode 100644 torchao/csrc/fp6_llm/README.md

diff --git a/torchao/csrc/fp6_llm/README.md b/torchao/csrc/fp6_llm/README.md
new file mode 100644
index 0000000000..ff764cc27d
--- /dev/null
+++ b/torchao/csrc/fp6_llm/README.md
@@ -0,0 +1,7 @@
+# FP6-LLM kernel
+
+This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs a linear op (A @ W.T), where A is in FP16 and W is in FP6 (E3M2, without infinities and NaNs).
+
+On most hardware, this kernel is faster than FP16 linear for batch sizes from 1 to 128, and slower for batch sizes of 256 or larger. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion.
+
+See https://github.com/pytorch/ao/pull/223 for some benchmark results.

From a8b4dd3bce1c6821f9f8d9079ee329c0b633ba3c Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Tue, 14 May 2024 01:13:03 +0000
Subject: [PATCH 39/39] relax tolerance

---
 test/test_ops.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 6d964396d7..e260e86f0f 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -133,7 +133,7 @@ def test_fp6_matmul_correctness(self, BS, OC, IC, splitK):
 
         error = (results_fp6 - results_fp16).abs()
         relative_error = error / results_fp16.abs()
-        assert relative_error.mean() < 1e-3
+        assert relative_error.mean() < 1e-2
 
 
 if __name__ == "__main__":
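
For reference, here is a minimal end-to-end usage sketch of the ops added in this patch series (prepack_fp6_weight, fp16act_fp6weight_linear, fp6_weight_dequant), mirroring the correctness test and benchmark above. The shapes and splitK value are chosen only for illustration; the sketch assumes a CUDA device and that the extension above has been built.

```python
import torch
import torchao

# Illustrative shapes only (assumption); the tests above use sizes such as 256 or 8192.
BS, OC, IC, splitK = 4, 256, 256, 1

# The FP6 weight is stored packed in int32: 3 x 32-bit words hold 16 FP6 values along IC.
fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
fp16_scale = torch.rand(OC).half() + 0.5          # one FP16 scale per output channel
fp16_activation = torch.rand(BS, IC).half() + 0.5

# Rearrange the packed weight into the layout expected by the CUDA kernel.
packed_weight = torchao.ops.prepack_fp6_weight(fp6_weight)

# FP16-activation x FP6-weight linear op on GPU.
result_fp6 = torchao.ops.fp16act_fp6weight_linear(
    fp16_activation.cuda(), packed_weight.cuda(), fp16_scale.cuda(), splitK
)

# Reference: dequantize the FP6 weight back to FP16 and use a plain matmul.
fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda()
result_fp16 = fp16_activation.cuda() @ fp16_weight.T

# Same relative-error check as the (relaxed) test above, which asserts a mean below 1e-2.
relative_error = (result_fp6 - result_fp16).abs() / result_fp16.abs()
print(relative_error.mean())
```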