diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h
old mode 100644
new mode 100755
index 4b376261d30d..a4460375be7c
--- a/csrc/punica/bgmv/bgmv_config.h
+++ b/csrc/punica/bgmv/bgmv_config.h
@@ -45,6 +45,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 9216) \
     f(in_T, out_T, W_T, narrow, 10240) \
     f(in_T, out_T, W_T, narrow, 11008) \
+    f(in_T, out_T, W_T, narrow, 11264) \
     f(in_T, out_T, W_T, narrow, 12288) \
     f(in_T, out_T, W_T, narrow, 13696) \
     f(in_T, out_T, W_T, narrow, 13824) \
@@ -53,6 +54,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 16384) \
     f(in_T, out_T, W_T, narrow, 20480) \
     f(in_T, out_T, W_T, narrow, 22016) \
+    f(in_T, out_T, W_T, narrow, 22528) \
     f(in_T, out_T, W_T, narrow, 24576) \
     f(in_T, out_T, W_T, narrow, 27392) \
     f(in_T, out_T, W_T, narrow, 27648) \
@@ -65,6 +67,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 36864) \
     f(in_T, out_T, W_T, narrow, 43264) \
     f(in_T, out_T, W_T, narrow, 49152) \
+    f(in_T, out_T, W_T, narrow, 60544) \
+    f(in_T, out_T, W_T, narrow, 60672) \
     f(in_T, out_T, W_T, narrow, 64000) \
     f(in_T, out_T, W_T, narrow, 64256) \
     f(in_T, out_T, W_T, narrow, 64512) \
@@ -74,6 +78,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 128000) \
     f(in_T, out_T, W_T, narrow, 128256) \
     f(in_T, out_T, W_T, narrow, 128512) \
+
+
 // Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
 // and vllm/tests/lora/test_punica.py
@@ -116,6 +122,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, 9216, narrow) \
     f(in_T, out_T, W_T, 10240, narrow) \
     f(in_T, out_T, W_T, 11008, narrow) \
+    f(in_T, out_T, W_T, 11264, narrow) \
     f(in_T, out_T, W_T, 12288, narrow) \
     f(in_T, out_T, W_T, 13696, narrow) \
     f(in_T, out_T, W_T, 13824, narrow) \
@@ -124,6 +131,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, 16384, narrow) \
     f(in_T, out_T, W_T, 20480, narrow) \
     f(in_T, out_T, W_T, 22016, narrow) \
+    f(in_T, out_T, W_T, 22528, narrow) \
     f(in_T, out_T, W_T, 24576, narrow) \
     f(in_T, out_T, W_T, 27392, narrow) \
     f(in_T, out_T, W_T, 27648, narrow) \
@@ -136,6 +144,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, 36864, narrow) \
     f(in_T, out_T, W_T, 43264, narrow) \
     f(in_T, out_T, W_T, 49152, narrow) \
+    f(in_T, out_T, W_T, 60544, narrow) \
+    f(in_T, out_T, W_T, 60672, narrow) \
     f(in_T, out_T, W_T, 64000, narrow) \
     f(in_T, out_T, W_T, 64256, narrow) \
     f(in_T, out_T, W_T, 64512, narrow) \
diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py
index f021c003b132..80c4c0139c67
--- a/tests/lora/test_punica.py
+++ b/tests/lora/test_punica.py
@@ -75,10 +75,12 @@ def _lora_ref_impl(
     9216,
     10240,
     11008,
+    11264,
     13824,
     14336,
     15360,
     22016,
+    22528,
     24576,
     27392,
     27648,
@@ -90,6 +92,8 @@ def _lora_ref_impl(
     36864,
     43264,
     49152,
+    60544,
+    60672,
     64000,
     64256,
     102400,
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index 84786921ce1b..4a14634d7319
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -29,7 +29,7 @@
 from transformers import CohereConfig
 
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -265,10 +265,14 @@ def __init__(
         config: CohereConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.vocab_size = config.vocab_size
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                    config.hidden_size)
         self.layers = nn.ModuleList([
@@ -302,18 +306,44 @@ def forward(
 
 
 class CohereForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
+    ]
+    embedding_modules = {"embed_tokens": "input_embeddings"}
+    embedding_padding_modules = []
+
     def __init__(
         self,
         config: CohereConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.quant_config = quant_config
-        self.logits_processor = LogitsProcessor(config.vocab_size,
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size,
                                                 scale=config.logit_scale)
-        self.model = CohereModel(config, cache_config, quant_config)
+        self.model = CohereModel(config,
+                                 cache_config,
+                                 quant_config,
+                                 lora_config=lora_config)
         self.sampler = Sampler()
 
     @torch.no_grad()
@@ -330,8 +360,14 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.model.embed_tokens.weight,
-                                       hidden_states, sampling_metadata)
+        is_not_lora = hasattr(self.model.embed_tokens, 'weight')
+        if is_not_lora:
+            embedding_weights = self.model.embed_tokens.weight
+        else:
+            embedding_weights = self.model.embed_tokens.base_layer.weight
+
+        logits = self.logits_processor(embedding_weights, hidden_states,
+                                       sampling_metadata)
         return logits
 
     def sample(
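
For reference, a minimal offline-inference sketch that would exercise the path this patch enables: with LoRA turned on, CohereForCausalLM goes through vLLM's LoRA layers and the punica bgmv kernels, which is why the extra feature sizes are registered in bgmv_config.h and test_punica.py above. The adapter name and local path below are hypothetical placeholders; the vLLM API calls themselves are the existing LLM/LoRARequest interface.

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # Load the base Command-R model with LoRA support enabled.
    llm = LLM(model="CohereForAI/c4ai-command-r-v01", enable_lora=True)

    sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

    # LoRARequest(name, id, local_path); the adapter path is a placeholder.
    outputs = llm.generate(
        ["Write a haiku about tensor parallelism."],
        sampling_params,
        lora_request=LoRARequest("commandr-adapter", 1, "/path/to/commandr-lora"),
    )
    print(outputs[0].outputs[0].text)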