diff --git a/cacheflow/models/layernorm.py b/cacheflow/models/layernorm.py
new file mode 100644
index 000000000000..37c41c5fce42
--- /dev/null
+++ b/cacheflow/models/layernorm.py
@@ -0,0 +1,26 @@
+import torch
+import torch.nn as nn
+
+from cacheflow import layernorm_ops
+
+
+class RMSNorm(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.empty_like(x)
+        layernorm_ops.rms_norm(
+            out,
+            x,
+            self.weight.data,
+            self.variance_epsilon,
+        )
+        return out
diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py
index 236feab4b4cb..d31131510369 100644
--- a/cacheflow/models/llama.py
+++ b/cacheflow/models/llama.py
@@ -12,6 +12,7 @@
 
 from cacheflow.models import InputMetadata
 from cacheflow.models.attention import LlamaCacheFlowAttention
+from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -23,22 +24,6 @@
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class LlamaRMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-        return self.weight * hidden_states
-
-
 class LlamaMLP(nn.Module):
     def __init__(
         self,
@@ -148,8 +133,8 @@ def __init__(self, config: LlamaConfig):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
         )
-        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -190,7 +175,7 @@ def __init__(self, config: LlamaConfig):
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
                                                    perform_initialization=False)
         self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
-        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu
index 23c2b2cda1e2..5b24120eadfb 100644
--- a/csrc/attention_kernels.cu
+++ b/csrc/attention_kernels.cu
@@ -3,6 +3,7 @@
 
 #include "attention_utils.h"
 #include "cuda_primitives.h"
+#include "reduction_utils.h"
 
 #include <algorithm>
 
diff --git a/csrc/attention_utils.h b/csrc/attention_utils.h
index ff59f43d7988..049555390715 100644
--- a/csrc/attention_utils.h
+++ b/csrc/attention_utils.h
@@ -159,45 +159,6 @@ struct Qk_dot {
   }
 };
 
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
-inline __device__ float block_sum(float* red_smem, float sum)
-{
-
-  // Decompose the thread index into warp / lane.
-  int warp = threadIdx.x / WARP_SIZE;
-  int lane = threadIdx.x % WARP_SIZE;
-
-  // Compute the sum per warp.
-#pragma unroll
-  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
-    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
-  }
-
-  // Warp leaders store the data to shared memory.
-  if (lane == 0) {
-    red_smem[warp] = sum;
-  }
-
-  // Make sure the data is in shared memory.
-  __syncthreads();
-
-  // The warps compute the final sums.
-  if (lane < WARPS_PER_BLOCK) {
-    sum = red_smem[lane];
-  }
-
-// Parallel reduction inside the warp.
-#pragma unroll
-  for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
-    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
-  }
-
-  // Broadcast to other threads.
-  return __shfl_sync(uint32_t(-1), sum, 0);
-}
-
 } // namespace cacheflow
 
 #undef MMHA_USE_FP32_ACUM_FOR_FMA
diff --git a/csrc/layernorm.cpp b/csrc/layernorm.cpp
new file mode 100644
index 000000000000..749ca5f92154
--- /dev/null
+++ b/csrc/layernorm.cpp
@@ -0,0 +1,14 @@
+#include <torch/extension.h>
+
+void rms_norm(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& weight,
+  float epsilon);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def(
+    "rms_norm",
+    &rms_norm,
+    "Apply Root Mean Square (RMS) Normalization to the input tensor.");
+}
diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
new file mode 100644
index 000000000000..84372ed2dd60
--- /dev/null
+++ b/csrc/layernorm_kernels.cu
@@ -0,0 +1,61 @@
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include "reduction_utils.h"
+
+namespace cacheflow {
+
+// TODO(woosuk): Further optimize this kernel.
+template<typename scalar_t>
+__global__ void rms_norm_kernel(
+  scalar_t* __restrict__ out,           // [num_tokens, hidden_size]
+  const scalar_t* __restrict__ input,   // [num_tokens, hidden_size]
+  const scalar_t* __restrict__ weight,  // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    const float x = (float) input[blockIdx.x * hidden_size + idx];
+    variance += x * x;
+  }
+  variance = blockReduceSum(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    float x = (float) input[blockIdx.x * hidden_size + idx];
+    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
+  }
+}
+
+} // namespace cacheflow
+
+void rms_norm(
+  torch::Tensor& out,      // [num_tokens, hidden_size]
+  torch::Tensor& input,    // [num_tokens, hidden_size]
+  torch::Tensor& weight,   // [hidden_size]
+  float epsilon) {
+  int num_tokens = input.size(0);
+  int hidden_size = input.size(1);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    input.scalar_type(),
+    "rms_norm_kernel",
+    [&] {
+      cacheflow::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        out.data_ptr<scalar_t>(),
+        input.data_ptr<scalar_t>(),
+        weight.data_ptr<scalar_t>(),
+        epsilon,
+        num_tokens,
+        hidden_size);
+    });
+}
diff --git a/csrc/reduction_utils.h b/csrc/reduction_utils.h
new file mode 100644
index 000000000000..f977ab70f1fe
--- /dev/null
+++ b/csrc/reduction_utils.h
@@ -0,0 +1,76 @@
+#pragma once
+
+namespace cacheflow {
+
+template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
+inline __device__ float block_sum(float* red_smem, float sum)
+{
+
+  // Decompose the thread index into warp / lane.
+  int warp = threadIdx.x / WARP_SIZE;
+  int lane = threadIdx.x % WARP_SIZE;
+
+  // Compute the sum per warp.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+  }
+
+  // Warp leaders store the data to shared memory.
+  if (lane == 0) {
+    red_smem[warp] = sum;
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The warps compute the final sums.
+  if (lane < WARPS_PER_BLOCK) {
+    sum = red_smem[lane];
+  }
+
+// Parallel reduction inside the warp.
+#pragma unroll
+  for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
+    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+  }
+
+  // Broadcast to other threads.
+  return __shfl_sync(uint32_t(-1), sum, 0);
+}
+
+#define FINAL_MASK 0xffffffff
+
+template<typename T>
+__inline__ __device__ T warpReduceSum(T val)
+{
+#pragma unroll
+  for (int mask = 16; mask > 0; mask >>= 1)
+    val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
+  return val;
+}
+
+/* Calculate the sum of all elements in a block */
+template<typename T>
+__inline__ __device__ T blockReduceSum(T val)
+{
+  static __shared__ T shared[32];
+  int lane = threadIdx.x & 0x1f;
+  int wid = threadIdx.x >> 5;
+
+  val = warpReduceSum(val);
+
+  if (lane == 0)
+    shared[wid] = val;
+
+  __syncthreads();
+
+  // Use blockDim.x / 32.f (rather than blockDim.x >> 5) so that blocks whose
+  // size is not a multiple of 32 still include the last partial warp.
+  val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
+  val = warpReduceSum(val);
+
+  return val;
+}
+
+} // namespace cacheflow
diff --git a/setup.py b/setup.py
index 9889918fb183..df7551989946 100644
--- a/setup.py
+++ b/setup.py
@@ -31,6 +31,14 @@
 )
 ext_modules.append(positional_encoding_extension)
 
+# Layer normalization kernels.
+layernorm_extension = cpp_extension.CUDAExtension(
+    name='cacheflow.layernorm_ops',
+    sources=['csrc/layernorm.cpp', 'csrc/layernorm_kernels.cu'],
+    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
+)
+ext_modules.append(layernorm_extension)
+
 setuptools.setup(
     name='cacheflow',
     ext_modules=ext_modules,
diff --git a/tests/kernels/layernorm.py b/tests/kernels/layernorm.py
new file mode 100644
index 000000000000..0e0072d879c2
--- /dev/null
+++ b/tests/kernels/layernorm.py
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+
+from cacheflow import layernorm_ops
+
+
+class RefRMSNorm(nn.Module):
+
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        weight = torch.randn(hidden_size) / (hidden_size ** 0.5)
+        self.weight = nn.Parameter(weight)
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
+            hidden_states = hidden_states.to(self.weight.dtype)
+        return self.weight * hidden_states
+
+
+@torch.inference_mode()
+def test_rms_norm(
+    num_tokens: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+) -> None:
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
+    ref = RefRMSNorm(hidden_size).to(dtype).cuda()
+
+    out = torch.empty_like(x)
+    layernorm_ops.rms_norm(
+        out,
+        x,
+        ref.weight.data,
+        ref.variance_epsilon,
+    )
+    ref_out = ref(x)
+    assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)
+
+
+if __name__ == '__main__':
+    for dtype in [torch.half, torch.float]:
+        for num_tokens in [7, 128, 2048]:
+            for hidden_size in [13, 64, 1024, 5120]:
+                print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
+                      f'{num_tokens}, hidden_size={hidden_size}')
+                test_rms_norm(
+                    num_tokens=num_tokens,
+                    hidden_size=hidden_size,
+                    dtype=dtype,
+                )
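
For context, a rough sense of the kernel's benefit can be obtained with a small timing script. The sketch below is not part of the patch above; it is a hypothetical micro-benchmark that assumes the cacheflow.layernorm_ops extension has been built via the updated setup.py and that a CUDA device is available. It compares the fused rms_norm kernel against an unfused PyTorch version of the same math used by RefRMSNorm; the helper names (ref_rms_norm, benchmark), shapes, and iteration counts are arbitrary choices, not anything defined in the patch.

import torch

from cacheflow import layernorm_ops


def ref_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Same math as RefRMSNorm in tests/kernels/layernorm.py, without the module wrapper.
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    return weight * (x * torch.rsqrt(variance + eps)).to(x.dtype)


@torch.inference_mode()
def benchmark(num_tokens: int = 4096, hidden_size: int = 5120,
              dtype: torch.dtype = torch.half, num_iters: int = 100) -> None:
    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
    weight = torch.ones(hidden_size, dtype=dtype, device='cuda')
    out = torch.empty_like(x)

    def run_kernel():
        # Fused CUDA kernel exposed by the new extension: rms_norm(out, input, weight, epsilon).
        layernorm_ops.rms_norm(out, x, weight, 1e-6)

    def run_ref():
        ref_rms_norm(x, weight, 1e-6)

    for name, fn in [('custom kernel', run_kernel), ('pytorch ref', run_ref)]:
        # Warm up, then time with CUDA events so the device work is actually measured.
        for _ in range(10):
            fn()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(num_iters):
            fn()
        end.record()
        torch.cuda.synchronize()
        print(f'{name}: {start.elapsed_time(end) / num_iters:.3f} ms/iter')


if __name__ == '__main__':
    benchmark()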