-
-
Notifications
You must be signed in to change notification settings - Fork 11.6k
[P/D][V1] KV Connector API V1 #15960
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
6a12481
34bea75
20ef2ac
430e402
300ddac
519dd3e
7e0695b
553f416
b18bd8f
55d1b5b
c50e620
b22fe38
da257aa
9751e0b
2b77bcd
5cbd434
4c6a93e
d8ec5a6
fcd2dc9
1f9c252
7350244
1586d58
e2ecc14
5accb53
31d807e
a73721a
00df670
4ebcc3e
da019df
90e8c53
8b3f606
0163070
de1e487
48c2eb2
e72e5e4
7833645
1881aa5
eca7a49
b0629bd
7766ca5
7b64acb
b1310fd
689379e
62e1421
5145566
20decdf
fc58dd5
25c9592
40e5d81
e64f745
74af233
7c31e29
7f57f3c
05349a5
8e1eadc
54e1491
9c4159c
1d8415d
406d6bf
3a24897
c6c4368
5dff6e9
4afa50e
09be260
3f7844d
d44f699
329f2e7
72041ca
f9f87f2
33f6e60
db28310
3701b5d
deb1323
be789bf
a3e5762
a03d707
f696000
1d85e63
521ed14
6709943
44ea156
8180101
c3a2cc6
5273e24
913325f
75c24d3
17a3618
b4bd117
01caf61
d8549cb
b362ef1
485b22e
78d523e
4c38138
7af6ce2
1ad993b
3a08dda
e49874d
dd7969a
e1f130e
611b782
f6b8bff
9609115
7ce3bd6
c3f38d7
6dfda44
81d008a
6d35884
79fe730
ad18a3b
edefdff
c1a1169
1b8ec0b
ff4b98f
17b61fb
ac0660d
ecfb4ea
abdddf0
8695d96
7b5ba2c
6be9cf9
5363ed0
247195d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from vllm import LLM, SamplingParams | ||
| from vllm.config import KVTransferConfig | ||
|
|
||
| # Read prompts from output.txt | ||
| prompts = [] | ||
| try: | ||
| with open("output.txt") as f: | ||
| for line in f: | ||
| prompts.append(line.strip()) | ||
| print(f"Loaded {len(prompts)} prompts from output.txt") | ||
| except FileNotFoundError: | ||
| print("Error: output.txt file not found") | ||
| exit(-1) | ||
|
|
||
| sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) | ||
|
|
||
| llm = LLM( | ||
| model="meta-llama/Llama-3.1-8B-Instruct", | ||
| enforce_eager=True, | ||
| gpu_memory_utilization=0.8, | ||
| kv_transfer_config=KVTransferConfig.from_cli( | ||
| '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both",' | ||
| '"kv_connector_extra_config": {"shared_storage_path": "local_storage"}}' | ||
| )) #, max_model_len=2048, max_num_batched_tokens=2048) | ||
|
|
||
| # 2ND generation (decode instance): prompts already extended by the prefill run | ||
| outputs = llm.generate(prompts, sampling_params) | ||
|
|
||
| new_prompts = [] | ||
|
||
| for output in outputs: | ||
| prompt = output.prompt | ||
| generated_text = output.outputs[0].text | ||
| new_prompts.append(prompt + generated_text) | ||
| print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from vllm import LLM, SamplingParams | ||
| from vllm.config import KVTransferConfig | ||
|
|
||
| context = "Hi " * 1000 | ||
| context2 = "Hey " * 500 | ||
| prompts = [ | ||
| context + "Hello, my name is", | ||
| context + "The capital of France is", | ||
| context2 + "Your name is", | ||
| context2 + "The capital of China is", | ||
| ] | ||
|
|
||
| sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) | ||
|
|
||
| llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", | ||
| enforce_eager=True, | ||
| gpu_memory_utilization=0.8, | ||
| kv_transfer_config=KVTransferConfig.from_cli( | ||
| '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both", ' | ||
| '"kv_connector_extra_config": ' | ||
| '{"shared_storage_path": "local_storage"}}') | ||
robertgshaw2-redhat marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ) #, max_model_len=2048, max_num_batched_tokens=2048) | ||
|
|
||
| # 1ST generation (prefill instance) | ||
| outputs = llm.generate( | ||
| prompts, | ||
| sampling_params, | ||
| ) | ||
|
|
||
| new_prompts = [] | ||
| for output in outputs: | ||
| prompt = output.prompt | ||
| generated_text = output.outputs[0].text | ||
| new_prompts.append(prompt + generated_text) | ||
| print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") | ||
|
|
||
| # Write new_prompts to output.txt | ||
| with open("output.txt", "w") as f: | ||
| for prompt in new_prompts: | ||
| f.write(prompt + "\n") | ||
| print(f"Saved {len(new_prompts)} prompts to output.txt") | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| rm -rf local_storage/ | ||
| rm output.txt | ||
|
|
||
| VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py | ||
| VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -10,6 +10,8 @@ | |||||||||||||||||
| from vllm.attention import AttentionType | ||||||||||||||||||
| from vllm.attention.selector import backend_name_to_enum, get_attn_backend | ||||||||||||||||||
| from vllm.config import CacheConfig, get_current_vllm_config | ||||||||||||||||||
| from vllm.distributed import (get_kv_transfer_group, has_kv_transfer_group, | ||||||||||||||||||
| is_v1_kv_transfer_group) | ||||||||||||||||||
| from vllm.forward_context import ForwardContext, get_forward_context | ||||||||||||||||||
| from vllm.model_executor.layers.linear import UnquantizedLinearMethod | ||||||||||||||||||
| from vllm.model_executor.layers.quantization.base_config import ( | ||||||||||||||||||
|
|
@@ -179,6 +181,11 @@ def forward( | |||||||||||||||||
| context using | ||||||||||||||||||
| `vllm.forward_context.get_forward_context().attn_metadata`. | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
||||||||||||||||||
| # KVConnector: wait for the (possibly async) load of kvs from the connector | ||||||||||||||||||
| # into this layer's KV cache before running attention. | ||||||||||||||||||
|
||||||||||||||||||
| wait_for_kv_layer_from_connector(self.layer_name) | ||||||||||||||||||
|
|
||||||||||||||||||
| if self.calculate_kv_scales: | ||||||||||||||||||
| attn_metadata = get_forward_context().attn_metadata | ||||||||||||||||||
| if attn_metadata.enable_kv_scales_calculation: | ||||||||||||||||||
|
|
@@ -217,18 +224,24 @@ def forward( | |||||||||||||||||
| else: | ||||||||||||||||||
| torch.ops.vllm.unified_attention_with_output( | ||||||||||||||||||
| query, key, value, output, self.layer_name) | ||||||||||||||||||
| return output.view(-1, hidden_size) | ||||||||||||||||||
| output = output.view(-1, hidden_size) | ||||||||||||||||||
| else: | ||||||||||||||||||
| if self.use_direct_call: | ||||||||||||||||||
| forward_context = get_forward_context() | ||||||||||||||||||
| attn_metadata = forward_context.attn_metadata | ||||||||||||||||||
| self_kv_cache = self.kv_cache[forward_context.virtual_engine] | ||||||||||||||||||
| return self.impl.forward(self, query, key, value, | ||||||||||||||||||
| self_kv_cache, attn_metadata) | ||||||||||||||||||
| output = self.impl.forward(self, query, key, value, | ||||||||||||||||||
| self_kv_cache, attn_metadata) | ||||||||||||||||||
| else: | ||||||||||||||||||
| return torch.ops.vllm.unified_attention( | ||||||||||||||||||
| output = torch.ops.vllm.unified_attention( | ||||||||||||||||||
| query, key, value, self.layer_name) | ||||||||||||||||||
|
|
||||||||||||||||||
| # KVConnector: start saving kvs to the connector. | ||||||||||||||||||
| # NOTE: forward_context completion will block until | ||||||||||||||||||
| # this operation is completed. | ||||||||||||||||||
| maybe_save_kv_layer_to_connector(self.layer_name, self.kv_cache) | ||||||||||||||||||
| return output | ||||||||||||||||||
|
|
||||||||||||||||||
| def calc_kv_scales(self, query, key, value): | ||||||||||||||||||
| self._q_scale.copy_(torch.abs(query).max() / self.q_range) | ||||||||||||||||||
| self._k_scale.copy_(torch.abs(key).max() / self.k_range) | ||||||||||||||||||
|
|
@@ -329,6 +342,38 @@ def forward( | |||||||||||||||||
| return out.reshape(bsz, q_len, -1) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def wait_for_kv_layer_from_connector(layer_name: str): | ||||||||||||||||||
| if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): | ||||||||||||||||||
| return | ||||||||||||||||||
|
|
||||||||||||||||||
| connector = get_kv_transfer_group() | ||||||||||||||||||
|
|
||||||||||||||||||
| forward_context: ForwardContext = get_forward_context() | ||||||||||||||||||
| attn_metadata = forward_context.attn_metadata | ||||||||||||||||||
| if attn_metadata is None: | ||||||||||||||||||
| return | ||||||||||||||||||
robertgshaw2-redhat marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||
|
|
||||||||||||||||||
| connector.wait_for_layer_load(layer_name) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def maybe_save_kv_layer_to_connector( | ||||||||||||||||||
| layer_name: str, | ||||||||||||||||||
| kv_cache: List[torch.Tensor], | ||||||||||||||||||
| ): | ||||||||||||||||||
| if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): | ||||||||||||||||||
| return | ||||||||||||||||||
|
|
||||||||||||||||||
| connector = get_kv_transfer_group() | ||||||||||||||||||
|
|
||||||||||||||||||
| forward_context: ForwardContext = get_forward_context() | ||||||||||||||||||
| attn_metadata = forward_context.attn_metadata | ||||||||||||||||||
| if attn_metadata is None: | ||||||||||||||||||
| return | ||||||||||||||||||
|
|
||||||||||||||||||
| kv_cache_layer = kv_cache[forward_context.virtual_engine] | ||||||||||||||||||
| connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata) | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @ApostaC, @robertgshaw2-redhat It seems that this function is what prevents using the KVConnector without
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Did you run into an issue? I tried with
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Well, it's a silent error. This is what I see:
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @robertgshaw2-redhat Is it related to the "no_compile_layers" issue? I'll take a look at this
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In order to take advantage of piecewise CUDA graphs, ref: Lines 3389 to 3396 in e1a2c69
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm making progress on this but still stuck. |
||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def unified_attention( | ||||||||||||||||||
| query: torch.Tensor, | ||||||||||||||||||
| key: torch.Tensor, | ||||||||||||||||||
|
|
||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # yapf: disable | ||
| from vllm.distributed.kv_transfer.kv_connector.v1.base import ( | ||
| KVConnectorBase_V1, KVConnectorRole) | ||
|
|
||
| # yapf: enable | ||
|
|
||
| __all__ = [ | ||
| "KVConnectorRole", | ||
| "KVConnectorBase_V1", | ||
| ] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@robertgshaw2-redhat Just realized there is a typo in this folder's name
disaggrated-prefill-v1->disaggregated-prefill-v1