File tree Expand file tree Collapse file tree 1 file changed +2
-4
lines changed
Expand file tree Collapse file tree 1 file changed +2
-4
lines changed Original file line number Diff line number Diff line change @@ -113,8 +113,7 @@ def _get_decode_wrapper(self):
113113 self .runner .parallel_config ))
114114 num_kv_heads = self .runner .model_config .get_num_kv_heads (
115115 self .runner .parallel_config )
116- use_tensor_cores = (num_qo_heads // num_kv_heads ) not in \
117- (1 , 2 , 4 , 8 )
116+ use_tensor_cores = num_qo_heads // num_kv_heads > 4
118117 self ._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper (
119118 self ._get_workspace_buffer (),
120119 "NHD" ,
@@ -172,8 +171,7 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
172171 self .runner .parallel_config ))
173172 num_kv_heads = self .runner .model_config .get_num_kv_heads (
174173 self .runner .parallel_config )
175- use_tensor_cores = (num_qo_heads // num_kv_heads ) not in \
176- (1 , 2 , 4 , 8 )
174+ use_tensor_cores = num_qo_heads // num_kv_heads > 4
177175 self ._graph_decode_wrapper = \
178176 CUDAGraphBatchDecodeWithPagedKVCacheWrapper (
179177 self ._graph_decode_workspace_buffer , _indptr_buffer ,
You can’t perform that action at this time.
0 commit comments