Skip to content

Commit e8dd159

Browse files
committed
Revert "[Kernel] unified_attention for Attention.forward (vllm-project#11967)"
This reverts commit 0f8cafe.
1 parent 872166d commit e8dd159

10 files changed

Lines changed: 45 additions & 79 deletions

vllm/attention/layer.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,19 @@ def forward(
148148
kv_cache: torch.Tensor,
149149
attn_metadata: AttentionMetadata,
150150
) -> torch.Tensor:
151+
<<<<<<< HEAD
151152
if self.calculate_kv_scales and \
152153
attn_metadata.enable_kv_scales_calculation:
153154
self.calc_kv_scales(key, value)
154155
if self.use_output:
156+
=======
157+
158+
if self.use_direct_call:
159+
return self.impl.forward(query, key, value, kv_cache,
160+
attn_metadata, self._k_scale,
161+
self._v_scale)
162+
elif self.use_output:
163+
>>>>>>> parent of 0f8cafe2 ([Kernel] unified_attention for Attention.forward (#11967))
155164
output = torch.empty_like(query)
156165
hidden_size = query.size(-1)
157166
# Reshape the query, key, and value tensors.
@@ -163,19 +172,12 @@ def forward(
163172
key = key.view(-1, self.num_kv_heads, self.head_size)
164173
if value is not None:
165174
value = value.view(-1, self.num_kv_heads, self.head_size)
166-
if self.use_direct_call:
167-
unified_attention_with_output(query, key, value, output,
168-
self.layer_name)
169-
else:
170-
torch.ops.vllm.unified_attention_with_output(
171-
query, key, value, output, self.layer_name)
175+
torch.ops.vllm.unified_attention_with_output(
176+
query, key, value, output, self.layer_name)
172177
return output.view(-1, hidden_size)
173178
else:
174-
if self.use_direct_call:
175-
return unified_attention(query, key, value, self.layer_name)
176-
else:
177-
return torch.ops.vllm.unified_attention(
178-
query, key, value, self.layer_name)
179+
return torch.ops.vllm.unified_attention(query, key, value,
180+
self.layer_name)
179181

180182
def calc_kv_scales(self, key, value):
181183
self._k_scale.copy_(torch.abs(key).max() / self.k_range)

vllm/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2183,6 +2183,7 @@ def bind_kv_cache(
21832183
forward_ctx = ctx[layer_name]
21842184
assert len(forward_ctx.kv_cache) == len(kv_cache)
21852185
for ve, ve_kv_cache in enumerate(kv_cache):
2186+
assert forward_ctx.kv_cache[ve].numel() == 0
21862187
forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
21872188

21882189

vllm/worker/hpu_model_runner.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from vllm.attention import AttentionMetadata, get_attn_backend
3030
from vllm.config import DeviceConfig, VllmConfig
3131
from vllm.distributed.parallel_state import get_world_group
32-
from vllm.forward_context import set_forward_context
3332
from vllm.logger import init_logger
3433
from vllm.lora.layers import LoRAMapping
3534
from vllm.lora.request import LoRARequest
@@ -42,8 +41,7 @@
4241
from vllm.sampling_params import SamplingParams
4342
from vllm.sequence import (IntermediateTensors, SequenceData,
4443
SequenceGroupMetadata)
45-
from vllm.utils import (bind_kv_cache, is_pin_memory_available,
46-
make_tensor_with_pad)
44+
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
4745
from vllm.worker.model_runner_base import (
4846
ModelRunnerBase, ModelRunnerInputBase,
4947
_add_attn_metadata_broadcastable_dict,
@@ -1300,9 +1298,6 @@ def create_dummy_seq_group_metadata(self,
13001298
def profile_run(self) -> None:
13011299
num_layers = self.model_config.get_num_layers(self.parallel_config)
13021300
kv_caches = [None] * num_layers
1303-
bind_kv_cache(
1304-
self.vllm_config.compilation_config.static_forward_context,
1305-
[kv_caches])
13061301
max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
13071302
max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
13081303
self.scheduler_config.max_num_seqs)

vllm/worker/hpu_worker.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from vllm.model_executor.layers.sampler import SamplerOutput
2323
from vllm.prompt_adapter.request import PromptAdapterRequest
2424
from vllm.sequence import ExecuteModelRequest
25-
from vllm.utils import bind_kv_cache
2625
from vllm.worker.cache_engine import CacheEngine
2726
from vllm.worker.hpu_model_runner import HPUModelRunner
2827
from vllm.worker.model_runner_base import ModelRunnerBase
@@ -282,8 +281,6 @@ def _init_cache_engine(self):
282281
self.cache_engine[ve].gpu_cache
283282
for ve in range(self.parallel_config.pipeline_parallel_size)
284283
]
285-
bind_kv_cache(self.compilation_config.static_forward_context,
286-
self.hpu_cache)
287284

288285
def _warm_up_model(self) -> None:
289286
# NOTE(kzawora): We should use virtual engine index here

vllm/worker/neuron_model_runner.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from transformers_neuronx.config import GenerationConfig
99

1010
from vllm.config import VllmConfig
11-
from vllm.forward_context import set_forward_context
1211
from vllm.logger import init_logger
1312
from vllm.model_executor import SamplingMetadata
1413
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -318,15 +317,13 @@ def execute_model(
318317
raise ValueError(
319318
"NeuronModelRunner does not support multi-step execution.")
320319

321-
with set_forward_context(None, self.vllm_config, 0):
322-
hidden_states = self.model(
323-
input_ids=model_input.input_tokens,
324-
positions=model_input.input_positions,
325-
input_block_ids=model_input.input_block_ids,
326-
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
327-
or {},
328-
device=self.device),
329-
)
320+
hidden_states = self.model(
321+
input_ids=model_input.input_tokens,
322+
positions=model_input.input_positions,
323+
input_block_ids=model_input.input_block_ids,
324+
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
325+
device=self.device),
326+
)
330327

331328
# Compute the logits only if the on-device sampling is turned off as
332329
# on-device sampling outputs the token ids.

vllm/worker/openvino_model_runner.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from vllm.attention import get_attn_backend
99
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
1010
from vllm.config import VllmConfig
11-
from vllm.forward_context import set_forward_context
1211
from vllm.logger import init_logger
1312
from vllm.model_executor import SamplingMetadata
1413
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -355,8 +354,7 @@ def execute_model(
355354
device=self.device),
356355
}
357356

358-
with set_forward_context(attn_metadata, self.vllm_config, 0):
359-
hidden_states = model_executable(**execute_model_kwargs)
357+
hidden_states = model_executable(**execute_model_kwargs)
360358

361359
# Compute the logits.
362360
logits = self.model.compute_logits(hidden_states, sampling_metadata)

vllm/worker/openvino_worker.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from vllm.platforms import current_platform
2222
from vllm.sampling_params import SamplingParams
2323
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
24-
from vllm.utils import bind_kv_cache
2524
from vllm.worker.openvino_model_runner import OpenVINOModelRunner
2625
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
2726

@@ -339,8 +338,6 @@ def _init_cache_engine(self) -> None:
339338
ov_device,
340339
)
341340
self.kv_cache = self.cache_engine.kv_cache
342-
bind_kv_cache(self.compilation_config.static_forward_context,
343-
[self.kv_cache])
344341
self.model_runner.block_size = self.cache_engine.block_size
345342

346343
assert self.kv_cache is not None
@@ -512,18 +509,12 @@ def model_profile_run():
512509

513510
self.model_runner.block_size = tmp_cache_config.block_size
514511

515-
bind_kv_cache(self.compilation_config.static_forward_context,
516-
profiling_cache_engine.kv_cache)
517512
# Run the model with the dummy inputs.
518513
self.model_runner.execute_model(seqs,
519514
profiling_cache_engine.kv_cache)
520515

521-
# Explicitly revert bind_kv_cache and delete temporary KV cache
522-
# manager to free KV cache when real inputs will be passed to OV
523-
bind_kv_cache(self.compilation_config.static_forward_context, [[
524-
torch.tensor([])
525-
for _ in range(len(profiling_cache_engine.kv_cache))
526-
]])
516+
# Explicitly delete the temporary KV cache manager to free its KV
517+
# cache before real inputs are passed to OV.
527518
del profiling_cache_engine
528519

529520
logger.info(

vllm/worker/tpu_model_runner.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
from vllm.attention import AttentionMetadata, get_attn_backend
1515
from vllm.config import VllmConfig
16-
from vllm.forward_context import set_forward_context
1716
from vllm.logger import init_logger
1817
from vllm.model_executor.layers.sampler import SamplerOutput
1918
from vllm.model_executor.model_loader import get_model
@@ -272,9 +271,8 @@ def _dummy_run(
272271
torch._dynamo.mark_dynamic(t, 0)
273272
torch._dynamo.mark_dynamic(p, 0)
274273
# Dummy run.
275-
with set_forward_context(attn_metadata, self.vllm_config, 0):
276-
self.model(token_ids, position_ids, attn_metadata, input_lens, t,
277-
p, num_samples, kv_caches)
274+
self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
275+
num_samples, kv_caches)
278276

279277
def warmup_model(
280278
self,
@@ -673,13 +671,10 @@ def execute_model(
673671
input_lens = model_input.input_lens[i:i + 1].to(self.device)
674672
t = model_input.t[i:i + 1].to(self.device)
675673
p = model_input.p[i:i + 1].to(self.device)
676-
with set_forward_context(model_input.attn_metadata,
677-
self.vllm_config,
678-
model_input.virtual_engine):
679-
output_token_ids = self.model(token_ids, position_ids,
680-
attn_metadata, input_lens, t,
681-
p, model_input.num_samples,
682-
kv_caches)
674+
output_token_ids = self.model(token_ids, position_ids,
675+
attn_metadata, input_lens, t, p,
676+
model_input.num_samples,
677+
kv_caches)
683678
next_token_ids.append(output_token_ids[0])
684679
start_idx = end_idx
685680

@@ -724,13 +719,10 @@ def execute_model(
724719
input_lens = model_input.input_lens.to(self.device)
725720
for i in range(num_steps):
726721
slot_mapping = attn_metadata.slot_mapping
727-
with set_forward_context(model_input.attn_metadata,
728-
self.vllm_config,
729-
model_input.virtual_engine):
730-
output_token_ids = self.model(token_ids, position_ids,
731-
attn_metadata, input_lens, t,
732-
p, model_input.num_samples,
733-
kv_caches)
722+
output_token_ids = self.model(token_ids, position_ids,
723+
attn_metadata, input_lens, t, p,
724+
model_input.num_samples,
725+
kv_caches)
734726
self.cached_step_outputs.append(output_token_ids)
735727

736728
if i < num_steps - 1:

vllm/worker/tpu_worker.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from vllm.logger import init_logger
1313
from vllm.model_executor import set_random_seed
1414
from vllm.sequence import ExecuteModelRequest
15-
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size
15+
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
1616
from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
1717
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
1818
LoraNotSupportedWorkerBase, WorkerBase,
@@ -108,8 +108,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
108108
torch.tensor([], dtype=torch.float32,
109109
device=self.device))
110110
for _ in range(num_layers)]
111-
bind_kv_cache(self.compilation_config.static_forward_context,
112-
[kv_caches])
113111
self.model_runner._dummy_run(
114112
batch_size=1,
115113
seq_len=self.scheduler_config.max_num_batched_tokens,
@@ -172,8 +170,6 @@ def initialize_cache(
172170
device="cpu")
173171
cpu_v_cache = torch.zeros_like(cpu_k_cache)
174172
self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
175-
bind_kv_cache(self.compilation_config.static_forward_context,
176-
[self.tpu_cache])
177173
self._warmup_model()
178174

179175
def _warmup_model(self) -> None:

vllm/worker/xpu_model_runner.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from vllm.attention import get_attn_backend
1313
from vllm.config import VllmConfig
1414
from vllm.distributed import get_pp_group
15-
from vllm.forward_context import set_forward_context
1615
from vllm.inputs import INPUT_REGISTRY, InputRegistry
1716
from vllm.logger import init_logger
1817
from vllm.model_executor import SamplingMetadataCache
@@ -574,17 +573,15 @@ def execute_model(
574573
if (self.observability_config is not None
575574
and self.observability_config.collect_model_forward_time):
576575
model_forward_start_time = time.time()
577-
with set_forward_context(model_input.attn_metadata, self.vllm_config,
578-
model_input.virtual_engine):
579-
hidden_or_intermediate_states = model_executable(
580-
input_ids=model_input.input_tokens,
581-
positions=model_input.input_positions,
582-
kv_caches=kv_caches,
583-
attn_metadata=model_input.attn_metadata,
584-
intermediate_tensors=intermediate_tensors,
585-
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
586-
or {},
587-
device=self.device))
576+
577+
hidden_or_intermediate_states = model_executable(
578+
input_ids=model_input.input_tokens,
579+
positions=model_input.input_positions,
580+
kv_caches=kv_caches,
581+
attn_metadata=model_input.attn_metadata,
582+
intermediate_tensors=intermediate_tensors,
583+
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
584+
device=self.device))
588585
# Compute the logits in the last pipeline stage.
589586
if not get_pp_group().is_last_rank:
590587
return hidden_or_intermediate_states

0 commit comments

Comments
 (0)