@@ -177,6 +177,18 @@ def triton_mrope(
     return q, k


+def apply_interleaved_rope(x: torch.Tensor,
+                           mrope_section: list[int]) -> torch.Tensor:
+    """Apply interleaved MRoPE to 3D rotary embeddings.
+    Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+    interleaved [THWTHWTHW...TT], preserving frequency continuity.
+    """
+    x_t = x[0].clone()
+    x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
+    x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
+    return x_t
+
+
 class MRotaryEmbedding(RotaryEmbedding):
     """Rotary Embedding with Multimodal Sections."""

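As a reader's aid (not part of the diff): a minimal sketch, assuming a toy mrope_section of [2, 2, 2] and dummy frequency values, of the layout apply_interleaved_rope produces. The temporal row is kept as the base, every third slot starting at index 1 is taken from the height row, and every third slot starting at index 2 from the width row, so the output reads T H W T H W ... with any trailing slots left temporal.

import torch

def apply_interleaved_rope(x, mrope_section):
    # Same body as in the diff above, repeated so this sketch is standalone.
    x_t = x[0].clone()
    x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
    x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
    return x_t

# Rows: x[0] = temporal, x[1] = height, x[2] = width frequencies (toy values).
x = torch.stack([
    torch.tensor([0., 1., 2., 3., 4., 5.]),
    torch.tensor([10., 11., 12., 13., 14., 15.]),
    torch.tensor([20., 21., 22., 23., 24., 25.]),
])
print(apply_interleaved_rope(x, [2, 2, 2]))
# tensor([ 0., 11., 22.,  3., 14., 25.])  ->  T H W T H W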
@@ -189,6 +201,7 @@ def __init__(
         is_neox_style: bool,
         dtype: torch.dtype,
         mrope_section: Optional[list[int]] = None,
+        mrope_interleaved: Optional[bool] = False,
     ) -> None:
         # In Qwen2.5-VL, the maximum index value is related to the duration of
         # the input video. We enlarge max_position_embeddings to 4 times to get
@@ -198,6 +211,7 @@ def __init__(
                          base, is_neox_style, dtype)
 
         self.mrope_section = mrope_section
+        self.mrope_interleaved = mrope_interleaved
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2
 
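For orientation, the assertion above only requires the three section sizes to cover half the rotary dimension, one entry per frequency pair. A hedged example with illustrative numbers in the style of Qwen2-VL configs (not taken from this diff):

rotary_dim = 128                # per-head rotary width (illustrative)
mrope_section = [16, 24, 24]    # temporal / height / width frequency counts
assert sum(mrope_section) == rotary_dim // 2   # 64 frequency pairs in total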
@@ -225,17 +239,20 @@ def forward_native(
         cos, sin = cos_sin.chunk(2, dim=-1)
         if positions.ndim == 2:
             assert self.mrope_section
-
-            cos = torch.cat([
-                m[i]
-                for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
-            ],
-                            dim=-1)
-            sin = torch.cat([
-                m[i]
-                for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
-            ],
-                            dim=-1)
+            if self.mrope_interleaved:
+                cos = apply_interleaved_rope(cos, self.mrope_section)
+                sin = apply_interleaved_rope(sin, self.mrope_section)
+            else:
+                cos = torch.cat([
+                    m[i] for i, m in enumerate(
+                        cos.split(self.mrope_section, dim=-1))
+                ],
+                                dim=-1)
+                sin = torch.cat([
+                    m[i] for i, m in enumerate(
+                        sin.split(self.mrope_section, dim=-1))
+                ],
+                                dim=-1)
 
         query_shape = query.shape
         query = query.view(num_tokens, -1, self.head_size)
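To make the two forward_native branches concrete, here is a standalone sketch with a toy tensor (not the vLLM forward path), where cos has shape [3, num_tokens, rotary_dim // 2], i.e. one row of frequencies per t/h/w axis. The existing else branch concatenates contiguous per-axis chunks (TT...HH...WW); the new interleaved branch round-robins the axes across the frequency dimension instead, so the two layouts pick different frequency indices per axis and correspond to models trained with the matching convention.

import torch

mrope_section = [2, 2, 2]                      # toy sections, sum == D/2 == 6
cos = torch.arange(3 * 4 * 6, dtype=torch.float32).view(3, 4, 6)

# Chunked layout: first 2 freqs from t, next 2 from h, last 2 from w.
chunked = torch.cat(
    [m[i] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)

# Interleaved layout: t, h, w alternate along the frequency dimension.
interleaved = cos[0].clone()
interleaved[..., 1:mrope_section[1] * 3:3] = cos[1, ..., 1:mrope_section[1] * 3:3]
interleaved[..., 2:mrope_section[2] * 3:3] = cos[2, ..., 2:mrope_section[2] * 3:3]

print(chunked.shape, interleaved.shape)   # both torch.Size([4, 6])
print(chunked[0])      # tensor([ 0.,  1., 26., 27., 52., 53.])
print(interleaved[0])  # tensor([ 0., 25., 50.,  3., 28., 53.])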
@@ -265,6 +282,10 @@ def forward_cuda(
         assert positions.ndim == 1 or positions.ndim == 2
         assert key is not None
 
+        if self.mrope_interleaved:
+            # TODO: add triton implementation to support mrope-interleaved
+            return self.forward_native(positions, query, key)
+
         num_tokens = positions.shape[-1]
         cos_sin = self.cos_sin_cache[positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
@@ -388,6 +409,15 @@ def get_input_positions_tensor(
                 context_len=context_len,
                 seq_len=seq_len,
             )
+        elif hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe"]:
+            return cls._qwen3vl_get_input_positions_tensor(
+                input_tokens=input_tokens,
+                hf_config=hf_config,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                context_len=context_len,
+                seq_len=seq_len,
+            )
         elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]:
             return cls._ernie_get_input_positions_tensor(
                 input_tokens=input_tokens,
@@ -526,6 +556,98 @@ def _glm4v_get_input_positions_tensor(
                                 len(input_tokens)).item()
         return llm_positions, mrope_position_delta
 
+    @classmethod
+    def _qwen3vl_get_input_positions_tensor(
+        cls,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: Union[list[list[int]], torch.Tensor],
+        video_grid_thw: Union[list[list[int]], torch.Tensor],
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+    ) -> tuple[torch.Tensor, int]:
+        """Get mrope input positions and delta value."""
+
+        video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw
+                          for _ in range(t)]
+
+        image_token_id = hf_config.image_token_id
+        video_token_id = hf_config.video_token_id
+        vision_start_token_id = hf_config.vision_start_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+
+        input_tokens_tensor = torch.tensor(input_tokens)
+        vision_start_indices = torch.argwhere(
+            input_tokens_tensor == vision_start_token_id).squeeze(1)
+        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
+        image_nums = (vision_tokens == image_token_id).sum()
+        video_nums = (vision_tokens == video_token_id).sum()
+        llm_pos_ids_list: list = []
+
+        st = 0
+        remain_images, remain_videos = image_nums, video_nums
+
+        image_index, video_index = 0, 0
+        for _ in range(image_nums + video_nums):
+            if image_token_id in input_tokens and remain_images > 0:
+                ed_image = input_tokens.index(image_token_id, st)
+            else:
+                ed_image = len(input_tokens) + 1
+            if video_token_id in input_tokens and remain_videos > 0:
+                ed_video = input_tokens.index(video_token_id, st)
+            else:
+                ed_video = len(input_tokens) + 1
+            if ed_image < ed_video:
+                t, h, w = (
+                    image_grid_thw[image_index][0],
+                    image_grid_thw[image_index][1],
+                    image_grid_thw[image_index][2],
+                )
+                image_index += 1
+                remain_images -= 1
+                ed = ed_image
+            else:
+                t, h, w = (
+                    video_grid_thw[video_index][0],
+                    video_grid_thw[video_index][1],
+                    video_grid_thw[video_index][2],
+                )
+                video_index += 1
+                remain_videos -= 1
+                ed = ed_video
+
+            llm_grid_t, llm_grid_h, llm_grid_w = \
+                t, h // spatial_merge_size, w // spatial_merge_size
+            text_len = ed - st
+
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
+                -1, llm_grid_h * llm_grid_w).flatten()
+            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
+                llm_grid_t, -1, llm_grid_w).flatten()
+            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
+                llm_grid_t, llm_grid_h, -1).flatten()
+            llm_pos_ids_list.append(
+                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            text_len = len(input_tokens) - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        mrope_position_delta = (llm_positions.max() + 1 -
+                                len(input_tokens)).item()
+        llm_positions = llm_positions[:, context_len:seq_len]
+        return llm_positions, mrope_position_delta
+
     @classmethod
     def _ernie_get_input_positions_tensor(
         cls,
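Finally, a hedged usage sketch of the new Qwen3-VL position helper with toy token ids, a SimpleNamespace standing in for the HF config, and a single 4x4 image patch grid (spatial_merge_size of 2 collapses it to 4 placeholder tokens). The import path of MRotaryEmbedding is assumed to be the module this diff patches.

from types import SimpleNamespace
# from <module patched above> import MRotaryEmbedding   # import path assumed

hf_config = SimpleNamespace(
    image_token_id=1,                  # toy ids, not real Qwen3-VL vocab ids
    video_token_id=2,
    vision_start_token_id=3,
    vision_config=SimpleNamespace(spatial_merge_size=2),
)
# 2 text tokens, vision_start, 4 image placeholders, 1 trailing text token.
input_tokens = [100, 101, 3, 1, 1, 1, 1, 102]

positions, delta = MRotaryEmbedding._qwen3vl_get_input_positions_tensor(
    input_tokens=input_tokens,
    hf_config=hf_config,
    image_grid_thw=[[1, 4, 4]],   # one image: t=1, 4x4 patches -> 2x2 after merge
    video_grid_thw=[],
    context_len=0,
    seq_len=len(input_tokens),
)
print(positions.shape)  # torch.Size([3, 8]) -- one (t, h, w) triple per token
print(delta)            # -2: image tokens share a temporal position, so the
                        # max position runs behind the raw token count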