                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionState, AttentionType)
-from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
+from vllm.attention.backends.utils import (PAD_SLOT_ID, PerLayerParameters,
+                                           compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
+                                           infer_global_hyperparameters,
                                            is_block_tables_empty)
-from vllm.attention.layer import Attention
 from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.config import get_current_vllm_config
 from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)
 
@@ -106,72 +107,6 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
         raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
 
 
-@dataclass
-class PerLayerParameters:
-    """
-    Currently, FlashInfer backend only support models in which all layers share
-    the same values for the following hyperparameters.
-    """
-
-    window_left: int
-    logits_soft_cap: Optional[float]
-    sm_scale: float
-
-
-def get_per_layer_parameters(
-        vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]:
-    """
-    Scan all attention layers and determine some hyperparameters
-    to use during `plan`.
-    """
-
-    layers = vllm_config.compilation_config.static_forward_context
-    per_layer_params: Dict[str, PerLayerParameters] = {}
-
-    for key, layer in layers.items():
-        assert isinstance(layer, Attention)
-
-        impl = layer.impl
-        assert isinstance(impl, FlashInferImpl)
-
-        # Infer hyperparameters from the attention layer
-        window_size = impl.sliding_window
-        window_left = window_size[0] if window_size is not None else -1
-        logits_soft_cap = impl.logits_soft_cap
-        sm_scale = impl.scale
-
-        per_layer_params[key] = PerLayerParameters(window_left,
-                                                   logits_soft_cap, sm_scale)
-
-    return per_layer_params
-
-
-def infer_global_hyperparameters(
-        per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters:
-    """
-    Currently, FlashInfer backend only support models in which all layers share
-    the same values for the following hyperparameters:
-    - `window_left`
-    - `logits_soft_cap`
-    - `sm_scale`
-
-    So this function asserts that all layers share the same values for these
-    hyperparameters and returns the global values.
-    """
-
-    assert len(per_layer_params) > 0, "No attention layers found in the model."
-
-    param_sets = list(per_layer_params.values())
-    global_params = param_sets[0]
-    for params in param_sets:
-        assert params == global_params, (
-            "FlashInfer backend currently only supports models in which all "
-            "layers share the same values for the following hyperparameters: "
-            "`window_left`, `logits_soft_cap`, `sm_scale`.")
-
-    return global_params
-
-
 class FlashInferState(AttentionState):
 
     def __init__(self, runner):
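
The block deleted above is not dropped: PerLayerParameters and infer_global_hyperparameters are now imported from vllm.attention.backends.utils (first hunk). Judging from the updated call sites in the remaining hunks, the shared helper folds the per-layer scan and the all-layers-must-agree check into a single entry point that takes the vLLM config plus the backend's AttentionImpl subclass. A minimal sketch under that assumption; it is not a verbatim copy of the relocated code:

# Sketch only: assumes the relocated helper keeps the deleted logic but is
# parameterized over the impl class, as the call sites
# infer_global_hyperparameters(self.vllm_config, FlashInferImpl) suggest.
from dataclasses import dataclass
from typing import Dict, Optional, Type


@dataclass
class PerLayerParameters:
    window_left: int
    logits_soft_cap: Optional[float]
    sm_scale: float


def infer_global_hyperparameters(vllm_config,
                                 impl_cls: Type) -> PerLayerParameters:
    # Collect window_left / logits_soft_cap / sm_scale from every attention
    # layer registered in the compilation config.
    layers = vllm_config.compilation_config.static_forward_context
    per_layer: Dict[str, PerLayerParameters] = {}
    for key, layer in layers.items():
        impl = layer.impl
        assert isinstance(impl, impl_cls)
        window_size = impl.sliding_window
        per_layer[key] = PerLayerParameters(
            window_size[0] if window_size is not None else -1,
            impl.logits_soft_cap, impl.scale)

    # All layers must agree on these values; return the shared set.
    assert per_layer, "No attention layers found in the model."
    params = list(per_layer.values())
    assert all(p == params[0] for p in params), (
        "All layers must share window_left, logits_soft_cap and sm_scale.")
    return params[0]
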
@@ -293,8 +228,8 @@ def graph_capture_get_metadata_for_batch(
                                             batch_size + 1,
                                             dtype=torch.int32)
 
-        global_params = infer_global_hyperparameters(
-            get_per_layer_parameters(self.vllm_config))
+        global_params = infer_global_hyperparameters(self.vllm_config,
+                                                     FlashInferImpl)
 
         attn_metadata = self.runner.attn_backend.make_metadata(
             num_prefills=0,
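
The graph-capture path above now makes a single call instead of chaining get_per_layer_parameters into infer_global_hyperparameters. A hedged usage sketch of that pattern from a caller's point of view; build_plan_params is illustrative only and not part of this diff:

# The old two-step form
#   infer_global_hyperparameters(get_per_layer_parameters(self.vllm_config))
# becomes one call that also names the expected impl class.
from vllm.attention.backends.utils import infer_global_hyperparameters
from vllm.config import get_current_vllm_config


def build_plan_params(impl_cls):
    # impl_cls would be FlashInferImpl here; another backend would pass its
    # own AttentionImpl subclass (signature inferred from this diff's call
    # sites, not verified against the shared utils module).
    vllm_config = get_current_vllm_config()
    params = infer_global_hyperparameters(vllm_config, impl_cls)
    return params.window_left, params.logits_soft_cap, params.sm_scale
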
@@ -652,7 +587,7 @@ def prepare(self):
             # - `logits_soft_cap`
             # - `sm_scale`
             inferred_params = infer_global_hyperparameters(
-                get_per_layer_parameters(self.vllm_config))
+                self.vllm_config, FlashInferImpl)
             self.global_hyperparameters = inferred_params
             self.window_left = inferred_params.window_left
             self.logits_soft_cap = inferred_params.logits_soft_cap