
Commit 0163070

Merge pull request #4 from robertgshaw2-redhat/rob-changes
Rob changes
2 parents: e2ecc14 + 8b3f606

File tree: 5 files changed, +21 −19 lines

examples/offline_inference/disaggrated-prefill-v1/decode_example.py

Lines changed: 2 additions & 2 deletions

@@ -17,8 +17,8 @@
 sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

 llm = LLM(
-    model="meta-llama/Llama-3.1-8B-Instruct",
-    enforce_eager=True,
+    model="meta-llama/Llama-3.2-1B-Instruct",
+    enforce_eager=False,
     gpu_memory_utilization=0.8,
     kv_transfer_config=KVTransferConfig.from_cli(
         '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both",'

examples/offline_inference/disaggrated-prefill-v1/prefill_example.py

Lines changed: 2 additions & 2 deletions

@@ -14,8 +14,8 @@

 sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

-llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
-          enforce_eager=True,
+llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+          enforce_eager=False,
           gpu_memory_utilization=0.8,
           kv_transfer_config=KVTransferConfig.from_cli(
               '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both", '
Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,5 @@
 rm -rf local_storage/
 rm output.txt

-VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py
-VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py
+VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=5 python3 prefill_example.py
+VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=5 python3 decode_example.py
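
Taken together, the example changes swap both scripts to the smaller meta-llama/Llama-3.2-1B-Instruct checkpoint and set enforce_eager=False so the compiled path is exercised; the prefill script still emits a single token (max_tokens=1) purely to populate the shared KV store, the decode script reads it back and generates up to ten, and the run script now pins both stages to GPU 5. A minimal sketch of the shared setup pattern, assuming the imports the example scripts use (LLM and SamplingParams from vllm, KVTransferConfig from vllm.config) and omitting connector fields that the hunks above truncate:

# Minimal sketch of the prefill/decode example setup after this commit.
# Assumptions: import paths match the example scripts; the connector JSON in
# the real files carries extra fields (e.g. the storage location) that are
# cut off in the hunks above, so only the visible keys are reproduced here.
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# Prefill side: max_tokens=1, just enough to write the KV cache.
# Decode side: max_tokens=10, reading the prefilled KV back.
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",  # was Llama-3.1-8B-Instruct
    enforce_eager=False,                       # was True
    gpu_memory_utilization=0.8,
    kv_transfer_config=KVTransferConfig.from_cli(
        '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both"}'),
)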

vllm/attention/layer.py

Lines changed: 11 additions & 12 deletions

@@ -181,11 +181,6 @@ def forward(
         context using
         `vllm.forward_context.get_forward_context().attn_metadata`.
         """
-
-        # KVConnector: start async saving kvs to connector
-        # to the layers KV cache before running attention.
-        wait_for_kv_layer_from_connector(self.layer_name)
-
         if self.calculate_kv_scales:
             attn_metadata = get_forward_context().attn_metadata
             if attn_metadata.enable_kv_scales_calculation:

@@ -236,10 +231,6 @@ def forward(
         output = torch.ops.vllm.unified_attention(
             query, key, value, self.layer_name)

-        # KVConnector: start saving kvs to the connector.
-        # NOTE: forward_context completion will block until
-        # this operation is completed.
-        maybe_save_kv_layer_to_connector(self.layer_name, self.kv_cache)
         return output

     def calc_kv_scales(self, query, key, value):

@@ -358,7 +349,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):

 def maybe_save_kv_layer_to_connector(
     layer_name: str,
-    kv_cache: List[torch.Tensor],
+    kv_cache_layer: List[torch.Tensor],
 ):
     if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
         return

@@ -370,7 +361,6 @@
     if attn_metadata is None:
         return

-    kv_cache_layer = kv_cache[forward_context.virtual_engine]
     connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)


@@ -380,11 +370,17 @@ def unified_attention(
     value: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
+    wait_for_kv_layer_from_connector(layer_name)
+
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
     self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
-    return self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
+    output = self.impl.forward(self, query, key, value, kv_cache,
+                               attn_metadata)
+
+    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    return output


 def unified_attention_fake(

@@ -412,6 +408,7 @@ def unified_attention_with_output(
     output: torch.Tensor,
     layer_name: str,
 ) -> None:
+    wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
     self = forward_context.no_compile_layers[layer_name]

@@ -424,6 +421,8 @@ def unified_attention_with_output(
                       attn_metadata,
                       output=output)

+    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+

 def unified_attention_with_output_fake(
     query: torch.Tensor,
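
The net effect of the layer.py hunks is that the KV-connector hooks move out of Attention.forward and into the unified_attention / unified_attention_with_output custom ops: each op now waits for the connector to finish loading the layer's KV right before the attention kernel runs and kicks off the save right after, handing over the already-selected per-virtual-engine cache (so maybe_save_kv_layer_to_connector no longer indexes by virtual_engine itself). A condensed sketch of the resulting control flow, simplified from the diff above:

# Condensed sketch of unified_attention after this commit (simplified from
# the diff above; type annotations and the *_fake variant are omitted).
def unified_attention(query, key, value, layer_name):
    # Block until the connector has finished loading this layer's KV
    # (a no-op when no KV transfer group is active).
    wait_for_kv_layer_from_connector(layer_name)

    forward_context = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    output = self.impl.forward(self, query, key, value, kv_cache,
                               attn_metadata)

    # Pass the per-virtual-engine cache straight to the connector; the
    # helper no longer does the virtual_engine indexing itself.
    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return output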

vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py

Lines changed: 4 additions & 1 deletion

@@ -196,6 +196,9 @@ def extract_kv_from_layer(

             Assume the shape of the layer is (2, num_pages, page_size, xxx).
             """
+            # TODO(rob): make this compatible with MLA.
+
+            assert layer.shape[0] == 2
             num_pages, page_size = layer.shape[1], layer.shape[2]
             return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
                                                                 ...]

@@ -208,7 +211,7 @@
                     layer_name, request.token_ids)
                 kv_cache = extract_kv_from_layer(kv_layer,
                                                  request.slot_mapping)
-                tensors = {"kv_cache": kv_cache.cpu().detach()}
+                tensors = {"kv_cache": kv_cache.detach().cpu()}
                 safetensors.torch.save_file(tensors, filename)

     def wait_for_save(self):
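
The new assert makes the layout assumption explicit: SharedStorageConnector expects each layer's cache to stack key and value along dim 0, i.e. shape (2, num_pages, page_size, ...), which does not hold for MLA caches, hence the TODO. A small standalone sketch of the extraction step after this change; the tensor shapes and slot indices below are made up for illustration and are not part of the commit:

# Standalone sketch of the extraction step; shapes are hypothetical.
import torch

def extract_kv_from_layer(layer: torch.Tensor,
                          slot_mapping: torch.Tensor) -> torch.Tensor:
    # Layout assumption the commit now asserts: K and V stacked on dim 0.
    assert layer.shape[0] == 2
    num_pages, page_size = layer.shape[1], layer.shape[2]
    # Flatten pages into slots, then gather the slots used by the request.
    return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]

layer = torch.randn(2, 4, 16, 8)            # (2, num_pages, page_size, head_dim)
slot_mapping = torch.tensor([0, 1, 2, 17])  # hypothetical slot indices
kv = extract_kv_from_layer(layer, slot_mapping)
print(kv.shape)  # torch.Size([2, 4, 8])

# Saving then detaches before copying to CPU (order swapped in this commit):
# tensors = {"kv_cache": kv.detach().cpu()}
# safetensors.torch.save_file(tensors, filename)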
