@@ -224,6 +224,7 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
224224 query_start_loc = query_start_loc_host ,
225225 device = self .runner .device ,
226226 data_type = kv_cache_dtype ,
227+ q_data_type = self .runner .model_config .dtype ,
227228 use_cuda_graph = True ,
228229 decode_wrapper = self ._graph_decode_wrapper ,
229230 prefill_wrapper = None )
@@ -292,6 +293,8 @@ class FlashInferMetadata(AttentionMetadata):
292293 page_size : Optional [int ] = None
293294 # The data type of the paged kv cache
294295 data_type : torch .dtype = None
296+ # The data type of the query
297+ q_data_type : torch .dtype = None
295298 device : torch .device = torch .device ("cuda" )
296299 is_profile_run : bool = False
297300
@@ -353,7 +356,10 @@ def begin_forward(self):
353356 self .page_size ,
354357 # Disable flashinfer's pos encoding and use vllm's rope.
355358 pos_encoding_mode = "NONE" ,
356- data_type = self .data_type )
359+ # kv-cache data type.
360+ data_type = self .data_type ,
361+ # query data type.
362+ q_data_type = self .q_data_type )
357363
358364 def asdict_zerocopy (self ,
359365 skip_fields : Optional [Set [str ]] = None
@@ -617,6 +623,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
617623 query_start_loc = query_start_loc ,
618624 device = device ,
619625 data_type = kv_cache_dtype ,
626+ q_data_type = self .runner .model_config .dtype ,
620627 use_cuda_graph = use_captured_graph ,
621628 is_profile_run = self .is_profile_run )
622629
0 commit comments