@@ -38,7 +38,7 @@
 
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
@@ -239,6 +239,8 @@ def __init__(
         super().__init__()
         # Per attention head and per partition values.
         world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = world_size
+        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
@@ -261,24 +263,41 @@ def __init__(
             raise RuntimeError(
                 f"Qwen2-VL does not support {self.attn_backend} backend now.")
 
+    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        # [s, b, 3 * head * head_dim]
+        seq_len, bs, _ = qkv.shape
+        if self.tp_size > 1:
+            qkv = tensor_model_parallel_all_gather(qkv)
+
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
+        q, k, v = qkv.chunk(3, dim=2)
+
+        # 3 * [s, b, head * head_dim]
+        if self.tp_size > 1:
+            splitter = partial(dist_utils.split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+            v = splitter(v)[self.tp_rank]
+
+        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
+        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
+                     self.hidden_size_per_attention_head)
+        q, k, v = (x.view(*new_shape) for x in (q, k, v))
+        return q, k, v
+
     def forward(
         self,
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
     ) -> torch.Tensor:
-        # [s, b, c] --> [s, b, head * 3 * head_dim]
-        x, _ = self.qkv(x)
 
-        # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
-        new_x_shape = x.size()[:-1] + (
-            self.num_attention_heads_per_partition,
-            3 * self.hidden_size_per_attention_head,
-        )
-        x = x.view(*new_x_shape)
+        # [s, b, c] --> [s, b, 3 * head * head_dim]
+        x, _ = self.qkv(x)
 
-        # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
-        q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
+        q, k, v = self.split_qkv(x)
         batch_size = q.shape[1]
 
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
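
With tensor parallelism, the column-parallel qkv projection hands each rank a contiguous slice of the [q | k | v] output dimension, so a naive chunk(3) on the local tensor would mix query, key, and value columns across ranks. The added split_qkv therefore all-gathers the projection, chunks the full tensor into q/k/v, and re-splits each along the head dimension to keep only this rank's heads. The single-process sketch below is illustrative only; the tensor sizes, the two-rank simulation, and the use of plain torch.chunk/torch.cat in place of vLLM's distributed helpers are assumptions.

import torch

# Hypothetical sizes for illustration only.
seq_len, bs, heads, head_dim, tp_size = 4, 2, 8, 16, 2

# Full [q | k | v] projection output: [s, b, 3 * head * head_dim].
full = torch.randn(seq_len, bs, 3 * heads * head_dim)

# Each rank's local column-parallel output is a contiguous slice of the last
# dimension, which does not line up with the q/k/v boundaries.
local = full.chunk(tp_size, dim=-1)

# Stand-in for tensor_model_parallel_all_gather: concatenating the local
# slices restores the full [q | k | v] layout.
gathered = torch.cat(local, dim=-1)

for rank in range(tp_size):
    q, k, v = gathered.chunk(3, dim=2)      # full q, k, v
    q = q.chunk(tp_size, dim=-1)[rank]      # keep only this rank's heads
    k = k.chunk(tp_size, dim=-1)[rank]
    v = v.chunk(tp_size, dim=-1)[rank]
    new_shape = (seq_len, bs, heads // tp_size, head_dim)
    q, k, v = (x.view(*new_shape) for x in (q, k, v))
    assert q.shape == (seq_len, bs, heads // tp_size, head_dim)
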
@@ -614,24 +633,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if name.endswith("qkv.weight"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size,
-                                                       visual_embed_dim)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif name.endswith("qkv.bias"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -935,6 +936,16 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     embedding_modules = {}
     embedding_padding_modules = []
 
+    # BitandBytes specific attributes
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
         "lm_head.": "language_model.lm_head.",
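
The new bitsandbytes_stacked_params_mapping attribute is consumed by vLLM's BitsAndBytes loading path to relate unfused checkpoint projections (q/k/v, gate/up) to the corresponding shard of the model's fused parameters; roughly, a name containing q_proj resolves to shard 0 of qkv_proj. The helper below is only a schematic of that name translation, not vLLM's loader code; to_fused_name and the sample weight name are invented for illustration.

from typing import Optional, Tuple

# Copy of the mapping added above: shard_name -> (fused_name, shard_index).
bitsandbytes_stacked_params_mapping = {
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

def to_fused_name(checkpoint_name: str) -> Tuple[str, Optional[int]]:
    """Translate an unfused checkpoint weight name to its fused parameter."""
    for shard_name, (fused_name, index) in (
            bitsandbytes_stacked_params_mapping.items()):
        if shard_name in checkpoint_name:
            return checkpoint_name.replace(shard_name, fused_name), index
    return checkpoint_name, None

# Prints ('model.layers.0.self_attn.qkv_proj.weight', 0).
print(to_fused_name("model.layers.0.self_attn.q_proj.weight"))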