Commit 9606c71

Revert #7509 (#7887)
1 parent 64cc644 commit 9606c71

File tree

1 file changed: +2 -4 lines changed


vllm/attention/backends/flashinfer.py

Lines changed: 2 additions & 4 deletions
@@ -113,8 +113,7 @@ def _get_decode_wrapper(self):
                 self.runner.parallel_config))
         num_kv_heads = self.runner.model_config.get_num_kv_heads(
             self.runner.parallel_config)
-        use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
-            (1, 2, 4, 8)
+        use_tensor_cores = num_qo_heads // num_kv_heads > 4
         self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
             self._get_workspace_buffer(),
             "NHD",
@@ -172,8 +171,7 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
                 self.runner.parallel_config))
         num_kv_heads = self.runner.model_config.get_num_kv_heads(
             self.runner.parallel_config)
-        use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
-            (1, 2, 4, 8)
+        use_tensor_cores = num_qo_heads // num_kv_heads > 4
         self._graph_decode_wrapper = \
             CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                 self._graph_decode_workspace_buffer, _indptr_buffer,
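The reverted condition enables FlashInfer's tensor-core decode path only when the grouped-query ratio (num_qo_heads // num_kv_heads) exceeds 4, rather than whenever the ratio falls outside (1, 2, 4, 8) as #7509 had it. Below is a minimal standalone sketch comparing the two predicates; the head counts are illustrative assumptions for demonstration, not values taken from any real vLLM model config.

# Sketch only: compares the use_tensor_cores heuristic restored by this
# commit against the one introduced in #7509. Head counts are made up.

def use_tensor_cores_pre_revert(num_qo_heads: int, num_kv_heads: int) -> bool:
    # #7509 condition: enable tensor cores for any group size outside 1, 2, 4, 8.
    return (num_qo_heads // num_kv_heads) not in (1, 2, 4, 8)

def use_tensor_cores_post_revert(num_qo_heads: int, num_kv_heads: int) -> bool:
    # Restored condition: enable tensor cores only for large GQA ratios.
    return num_qo_heads // num_kv_heads > 4

if __name__ == "__main__":
    # Ratios 3 and 8 are the cases in this list where the two heuristics disagree.
    for qo, kv in [(32, 32), (12, 4), (32, 8), (32, 4), (28, 4)]:
        ratio = qo // kv
        print(f"ratio={ratio}: #7509={use_tensor_cores_pre_revert(qo, kv)}, "
              f"reverted={use_tensor_cores_post_revert(qo, kv)}")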
