 # limitations under the License.
 """Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping
-from functools import partial
+from functools import lru_cache, partial
 from typing import Callable, Literal, Optional, TypedDict, Union
 
 import torch
@@ -478,8 +478,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None:
         super().__init__()
         self.dim = dim
         self.theta = theta
-        inv_freq = 1.0 / (theta
-                          **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        inv_freq = 1.0 / (theta**(
+            torch.arange(0, dim, 2, dtype=torch.float, device='cpu') / dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self._seq_len_cached = 0
         self._freqs_cached = None
@@ -520,7 +520,7 @@ def __init__(
         self.hidden_size = vision_config.hidden_size
         self.num_heads = vision_config.num_heads
 
-        # args for get_window_index
+        # args for get_window_index_thw
         self.window_size = vision_config.window_size
         self.patch_size = vision_config.patch_size
         self.spatial_merge_size = vision_config.spatial_merge_size
@@ -567,65 +567,71 @@ def dtype(self) -> torch.dtype:
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device
 
-    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
-        pos_ids = []
-        for t, h, w in grid_thw:
-            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            hpos_ids = hpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            ).permute(0, 2, 1, 3).flatten()
-            wpos_ids = wpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            ).permute(0, 2, 1, 3).flatten()
-            pos_ids.append(
-                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
-        pos_ids = torch.cat(pos_ids, dim=0)
-        max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+    def rotary_pos_emb_thw(self, t, h, w):
+        hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+        wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+        hpos_ids = hpos_ids.reshape(
+            h // self.spatial_merge_size,
+            self.spatial_merge_size,
+            w // self.spatial_merge_size,
+            self.spatial_merge_size,
+        ).permute(0, 2, 1, 3).flatten()
+        wpos_ids = wpos_ids.reshape(
+            h // self.spatial_merge_size,
+            self.spatial_merge_size,
+            w // self.spatial_merge_size,
+            self.spatial_merge_size,
+        ).permute(0, 2, 1, 3).flatten()
+        pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
+        max_size = max(h, w)
+        rotary_pos_emb_full = self.rotary_pos_emb(max_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        rotary_pos_emb = rotary_pos_emb.reshape(
+            rotary_pos_emb.shape[0] // self.spatial_merge_unit,
+            self.spatial_merge_unit, -1)
+
         return rotary_pos_emb
 
-    def get_window_index(self, grid_thw):
-        window_index: list = []
-        cu_window_seqlens: list = [0]
-        window_index_id = 0
+    def get_window_index_thw(self, grid_t, grid_h, grid_w):
         vit_merger_window_size = (self.window_size //
                                   self.spatial_merge_size // self.patch_size)
 
-        for grid_t, grid_h, grid_w in grid_thw:
-            llm_grid_h = grid_h // self.spatial_merge_size
-            llm_grid_w = grid_w // self.spatial_merge_size
-            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
-                grid_t, llm_grid_h, llm_grid_w)
-            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
-            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
-            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
-            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
-            index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
-            index_padded = index_padded.reshape(grid_t, num_windows_h,
-                                                vit_merger_window_size,
-                                                num_windows_w,
-                                                vit_merger_window_size)
-            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
-                grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
-                vit_merger_window_size)
-            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
-            index_padded = index_padded.reshape(-1)
-            index_new = index_padded[index_padded != -100]
-            window_index.append(index_new + window_index_id)
-            cu_seqlens_tmp = seqlens.cumsum(
-                0) * self.spatial_merge_unit + cu_window_seqlens[-1]
-            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
-            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
-        window_index = torch.cat(window_index, dim=0)
-        return window_index, cu_window_seqlens
+        llm_grid_h = grid_h // self.spatial_merge_size
+        llm_grid_w = grid_w // self.spatial_merge_size
+        index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+            grid_t, llm_grid_h, llm_grid_w)
+        pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+        pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+        num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+        num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+        index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
+        index_padded = index_padded.reshape(grid_t, num_windows_h,
+                                            vit_merger_window_size,
+                                            num_windows_w,
+                                            vit_merger_window_size)
+        index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+            grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
+            vit_merger_window_size)
+        seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+        index_padded = index_padded.reshape(-1)
+        index_new = index_padded[index_padded != -100]
+        cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit
+        cu_seqlens_tmp = cu_seqlens_tmp.to(dtype=torch.int32)
+        cu_seqlens_tmp = torch.unique_consecutive(cu_seqlens_tmp)
+
+        return index_new, cu_seqlens_tmp
+
+    @lru_cache(maxsize=1024)  # noqa: B019
+    def get_rope_by_thw(self, t, h, w):
+        window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(
+            t, h, w)
+        rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w)
+        rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :]
+        rotary_pos_emb_thw = rotary_pos_emb_thw.flatten(start_dim=0, end_dim=1)
+        cu_seqlens_thw = torch.repeat_interleave(
+            torch.tensor([h * w], dtype=torch.int32), t)
+        return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw,
+                cu_seqlens_thw)
 
     def compute_attn_mask_seqlen(
         self,
@@ -641,45 +647,74 @@ def compute_attn_mask_seqlen(
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: torch.Tensor,
+        grid_thw: list[list[int]],
     ) -> torch.Tensor:
         # patchify
+        seq_len, _ = x.size()
+        rotary_pos_emb = []
+        window_index: list = []
+        cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int32)]
+        cu_seqlens: list = []
+
         hidden_states = x.to(device=self.device, dtype=self.dtype)
         hidden_states = self.patch_embed(hidden_states)
 
-        # compute position embedding
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        window_index_id = 0
+        cu_window_seqlens_last = 0
+        for t, h, w in grid_thw:
+            t, h, w = int(t), int(h), int(w)
+            llm_h = h // self.spatial_merge_size
+            llm_w = w // self.spatial_merge_size
+
+            (
+                rotary_pos_emb_thw,
+                window_index_thw,
+                cu_seqlens_window_thw,
+                cu_seqlens_thw,
+            ) = self.get_rope_by_thw(t, h, w)
+
+            window_index.append(window_index_thw + window_index_id)
+            window_index_id += (t * llm_h * llm_w)
+
+            cu_seqlens_window_thw = (cu_seqlens_window_thw +
+                                     cu_window_seqlens_last)
+            cu_window_seqlens_last = cu_seqlens_window_thw[-1]
+            cu_window_seqlens.append(cu_seqlens_window_thw)
 
-        # windows attention
-        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
-        cu_window_seqlens = torch.tensor(
-            cu_window_seqlens,
-            device=hidden_states.device,
-            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
+            rotary_pos_emb.append(rotary_pos_emb_thw)
+
+            cu_seqlens.append(cu_seqlens_thw)
+
+        rotary_pos_emb = torch.cat(rotary_pos_emb)
+        window_index = torch.cat(window_index)
+        cu_window_seqlens = torch.cat(cu_window_seqlens)
         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
-        seq_len, _ = hidden_states.size()
-        hidden_states = hidden_states.reshape(
-            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
-        hidden_states = hidden_states[window_index, :, :]
-        hidden_states = hidden_states.reshape(seq_len, -1)
-        rotary_pos_emb = rotary_pos_emb.reshape(
-            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
-        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
-        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
-        # compute cu_seqlens
-        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
-                                             grid_thw[:, 0]).cumsum(
-                                                 dim=0, dtype=torch.int32)
+        cu_seqlens = torch.cat(cu_seqlens)
+        cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
 
         # transformers
-        hidden_states = hidden_states.unsqueeze(1)
-
         # pre-compute seqlens for window/full attn to reduce cuMemcpy operations
         max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(
             cu_seqlens)
         max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
             cu_window_seqlens)
+
+        cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
+        cu_window_seqlens = cu_window_seqlens.to(device=self.device,
+                                                 non_blocking=True)
+        rotary_pos_emb = rotary_pos_emb.to(device=self.device,
+                                           non_blocking=True)
+        window_index = window_index.to(device=hidden_states.device,
+                                       non_blocking=True)
+
+        hidden_states = hidden_states.reshape(
+            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+        hidden_states = hidden_states[window_index, :, :]
+        hidden_states = hidden_states.reshape(seq_len, -1)
+
+        hidden_states = hidden_states.unsqueeze(1)
+
         for layer_num, blk in enumerate(self.blocks):
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
@@ -932,12 +967,13 @@ def _process_image_input(
 
         grid_thw = image_input["image_grid_thw"]
         assert grid_thw.ndim == 2
+        grid_thw_list = grid_thw.tolist()
 
         if image_input["type"] == "image_embeds":
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
         else:
             pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
 
         # Split concatenated embeddings for each image item.
         merge_size = self.visual.spatial_merge_size
@@ -951,13 +987,15 @@ def _process_video_input(
 
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
+        grid_thw_list = grid_thw.tolist()
 
         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
         else:
             pixel_values_videos = video_input["pixel_values_videos"].type(
                 self.visual.dtype)
-            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+            video_embeds = self.visual(pixel_values_videos,
+                                       grid_thw=grid_thw_list)
 
         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
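
Note on the caching pattern in this diff: get_rope_by_thw is memoized with functools.lru_cache, which requires hashable arguments. That is why the vision forward now accepts grid_thw as list[list[int]] and the callers convert the tensor with grid_thw.tolist() before calling self.visual. A minimal, self-contained sketch of the same pattern follows; the class and method names (RopePlanCache, get_plan) are illustrative placeholders, not part of the vLLM API.

from functools import lru_cache

import torch


class RopePlanCache:
    """Illustrative only: memoize a per-(t, h, w) computation.

    The cache key is a tuple of plain Python ints, so callers must pass
    ints (e.g. via grid_thw.tolist()), not 0-d tensors, or every call
    would miss the cache.
    """

    def __init__(self, spatial_merge_size: int = 2):
        self.spatial_merge_size = spatial_merge_size

    @lru_cache(maxsize=1024)  # noqa: B019 - bounded cache, keyed on (self, t, h, w)
    def get_plan(self, t: int, h: int, w: int) -> torch.Tensor:
        # Stand-in for the expensive rotary / window-index computation.
        llm_h = h // self.spatial_merge_size
        llm_w = w // self.spatial_merge_size
        return torch.arange(t * llm_h * llm_w).reshape(t, llm_h, llm_w)


cache = RopePlanCache()
grid_thw = torch.tensor([[1, 4, 6], [1, 4, 6]])
for t, h, w in grid_thw.tolist():   # plain ints -> hashable cache key
    plan = cache.get_plan(t, h, w)  # second iteration hits the cache
print(RopePlanCache.get_plan.cache_info())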