[Feature] Support Prefill Context Parallel (PCP) for GQA flashinfer #28723
base: main
Conversation
Code Review
This pull request introduces support for Prefill Context Parallelism (PCP) for GQA with flashinfer, which is a significant feature for enhancing long-sequence inference. The changes are extensive, touching configuration, parallel state management, attention backends, and the model runner. Overall, the implementation looks solid, but I've identified a few critical issues that need to be addressed. These include a duplicated command-line argument, a syntax error, a typo in a variable name, and incorrect tensor indexing, all of which could lead to runtime errors or prevent the code from running.
Co-authored-by: QiuChunshuo <[email protected]>
Co-authored-by: FENP <[email protected]>
Co-authored-by: LookAround <[email protected]>
Co-authored-by: Jingchun Gao <[email protected]>
Co-authored-by: zhenwenqi2024 <[email protected]>
Signed-off-by: QiuChunshuo <[email protected]>
Signed-off-by: FENP <[email protected]>
Signed-off-by: LookAround <[email protected]>
Signed-off-by: Jingchun Gao <[email protected]>
Signed-off-by: zhenwenqi2024 <[email protected]>
LucasWilkinson left a comment:
Left some comments on #28988 which I think similarly apply here
self.is_mm_embed = self._make_buffer(max_num_tokens, dtype=torch.bool)

# Persistent buffers for Prefill Context Parallelism
if self.pcp_world_size > 1:
can we please separate all of this into a PCPManager or a utils file to make it more modular and easier to migrate to model runner v2?
OK. We will make similar changes to those in #28988.
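For readers skimming the thread, here is a minimal stand-alone sketch of what such an extraction could look like. The class name PCPManager comes from the review comment above; the buffer fields, constructor arguments, and method names are hypothetical and are not the actual refactor in #28988:

```python
import torch


class PCPManager:
    """Hypothetical helper that owns the persistent PCP buffers.

    Keeping the prefill-context-parallel bookkeeping here instead of in the
    model runner makes it easier to reuse with other runner implementations.
    """

    def __init__(self, max_num_tokens: int, pcp_world_size: int, device: torch.device):
        self.pcp_world_size = pcp_world_size
        if pcp_world_size > 1:
            # Per-rank token indices after partitioning the sequence dimension.
            self.local_token_indices = torch.zeros(
                max_num_tokens, dtype=torch.int64, device=device
            )
            # Number of tokens assigned to this rank for the current batch.
            self.num_local_tokens = 0

    @property
    def enabled(self) -> bool:
        return self.pcp_world_size > 1
```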
LucasWilkinson left a comment:
Thanks for the contribution! A few more comments
vllm/v1/spec_decode/eagle.py (Outdated)
dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
cp_local_seq_lens=common_attn_metadata.cp_local_seq_lens,
cp_local_seq_lens_cpu=common_attn_metadata.cp_local_seq_lens_cpu,
pcp_metadata=common_attn_metadata.pcp_metadata,
if PCP + spec-decode is not yet supported, should we be passing this blindly?
Sure. We have removed these params to avoid confusion.
    ]
)
else:
    wrappers_to_check.append((prefill_wrapper._new_tokens, True))
this is kinda messy imo; can you just do something like:

class BatchCPPrefillWrapper:
    @property
    def _window_left(self):
        assert self._context._window_left == self._new_tokens._window_left
        return self._context._window_left
    ...

Let's also do:

class FlashInferImpl:
    def __init__(self, ...):
        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

and then only run the asserts if self.is_debugging_mode to avoid excessive CPU overhead on each forward pass
Thanks for your comments. We have changed it.
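For illustration, a self-contained sketch of the suggested pattern (class and attribute names follow the review comment above; vLLM's envs module is swapped for a plain os.environ lookup so the snippet runs stand-alone, and the inner wrappers are simply assumed to expose a _window_left attribute):

```python
import os


class BatchCPPrefillWrapper:
    """Hypothetical wrapper pairing a context wrapper with a new-tokens wrapper."""

    def __init__(self, context_wrapper, new_tokens_wrapper):
        self._context = context_wrapper
        self._new_tokens = new_tokens_wrapper
        # Only pay for the consistency checks when debugging.
        self._is_debugging_mode = os.environ.get("VLLM_LOGGING_LEVEL", "INFO") == "DEBUG"

    @property
    def _window_left(self):
        # Both inner wrappers must agree on the sliding-window setting.
        if self._is_debugging_mode:
            assert self._context._window_left == self._new_tokens._window_left
        return self._context._window_left
```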
num_actual_tokens = attn_metadata.num_actual_tokens

if self.pcp_world_size > 1:
can this be moved into BatchCPPrefillWrapper.run(...)?
No. The all-gather & restore-KV step is required in both the prefill and decode phases, so it cannot be moved into the PrefillWrapper. However, we have extracted it into vllm/v1/attention/backends/utils.py as a general function that other backends can reuse.
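As a rough illustration of what such a shared helper might do, here is a sketch under assumptions: the function name, the restore_indices argument, and the concatenate-then-reorder step are hypothetical and not the PR's actual utils function, and it needs an initialized torch.distributed process group to run:

```python
from __future__ import annotations

import torch
import torch.distributed as dist


def allgather_and_restore_kv(
    key: torch.Tensor,    # [num_local_tokens, num_kv_heads, head_dim]
    value: torch.Tensor,  # same shape as key
    restore_indices: torch.Tensor,  # maps gathered rows back to original token order
    group: dist.ProcessGroup | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Gather every rank's K/V shard and restore the original token order.

    Both prefill and decode need the full K/V, which is why this would live in
    a shared utils module rather than inside a prefill-only wrapper.
    """
    world_size = dist.get_world_size(group)
    gathered_k = [torch.empty_like(key) for _ in range(world_size)]
    gathered_v = [torch.empty_like(value) for _ in range(world_size)]
    dist.all_gather(gathered_k, key.contiguous(), group=group)
    dist.all_gather(gathered_v, value.contiguous(), group=group)
    full_k = torch.cat(gathered_k, dim=0)
    full_v = torch.cat(gathered_v, dim=0)
    # Undo the partitioning that was applied when tokens were scattered
    # across PCP ranks.
    return full_k[restore_indices], full_v[restore_indices]
```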
out,
lse,
get_dcp_group(),
return_lse=True,
should this be return_lse=self.pcp_world_size > 1?
Sure. We have changed it to avoid unnecessary computation.
vllm/v1/attention/backends/utils.py (Outdated)
kv_for_head_indices: torch.Tensor | None = None
kv_for_tail_indices: torch.Tensor | None = None
kv_for_head_indptr: torch.Tensor | None = None
kv_for_tail_indptr: torch.Tensor | None = None
this seems FlashInfer-specific? we shouldn't have backend-specific things (or styles) in CommonAttentionMetadata
Sure. The *_indptr params are FlashInfer-specific, so we removed them from the general PrefillContextParallelMetadata and moved the computation of PrefillContextParallelMetadata to the stage where the attention wrapper is built.
When building the attention wrapper, we compute the *_indptr params in the backend-specific flashinfer.py and then use general helpers such as get_q_indices to build PrefillContextParallelMetadata. We think this refactoring not only separates the handling of backend-specific parameters but also reduces the dependency on the model runner.
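As a sketch of the indptr-to-indices idea mentioned above (the real get_q_indices helper in the PR may behave differently; this only illustrates expanding a CSR-style indptr into per-token request indices):

```python
import torch


def get_q_indices(q_indptr: torch.Tensor) -> torch.Tensor:
    """Expand a CSR-style indptr into per-token request indices.

    Example: indptr [0, 2, 5] -> [0, 0, 1, 1, 1]
    (request 0 owns 2 tokens, request 1 owns 3 tokens).
    """
    counts = (q_indptr[1:] - q_indptr[:-1]).long()
    return torch.repeat_interleave(
        torch.arange(counts.numel(), device=q_indptr.device), counts
    )
```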
k_head = torch.index_select(key, 0, kv_for_head_indices)
v_head = torch.index_select(value, 0, kv_for_head_indices)
k_tail = torch.index_select(key, 0, kv_for_tail_indices)
v_tail = torch.index_select(value, 0, kv_for_tail_indices)
so many index_select calls seem very expensive; have you profiled this?
Yes, we profiled it. Because PCP is only meaningful for long input sequences, we ran performance tests with an input sequence length of 32k; the index_select cost is negligible compared to the gains from PCP.
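For anyone who wants to reproduce a rough number, a micro-benchmark along these lines could be used (the shapes, iteration count, and head/tail split are assumptions for illustration, not the PR's actual profiling setup):

```python
import time

import torch


def time_index_select(seq_len: int = 32 * 1024, num_kv_heads: int = 8, head_dim: int = 128):
    """Rough timing of a head/tail index_select pair at a 32k-token prefill."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    kv = torch.randn(seq_len, num_kv_heads, head_dim, device=device)
    half = seq_len // 2
    head_idx = torch.arange(0, half, device=device)
    tail_idx = torch.arange(half, seq_len, device=device)

    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(100):
        torch.index_select(kv, 0, head_idx)
        torch.index_select(kv, 0, tail_idx)
    if device == "cuda":
        torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - start) / 100 * 1e3
    print(f"avg head+tail index_select: {elapsed_ms:.3f} ms")


if __name__ == "__main__":
    time_index_select()
```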
    **common_kwargs,
)

def _attention_with_head_and_tail(
can this be moved into the BatchCPPrefillWrapper?
Sure. We have moved it into BatchCPPrefillWrapper.run().
if return_lse:
    output_head, lse_head = output_head
    output_tail, lse_tail = output_tail
    output = torch.index_select(
        torch.cat([output_head, output_tail], dim=0),
        0,
        q_full_indices,
    )
    lse = torch.index_select(
        torch.cat([lse_head, lse_tail], dim=0),
        0,
        q_full_indices,
    )
    return output, lse
else:
    output = torch.index_select(
        torch.cat([output_head, output_tail], dim=0),
        0,
        q_full_indices,
    )
    return output
Suggested change:

if return_lse:
    output_head, lse_head = output_head
    output_tail, lse_tail = output_tail
    lse = torch.index_select(
        torch.cat([lse_head, lse_tail], dim=0),
        0,
        q_full_indices,
    )
output = torch.index_select(
    torch.cat([output_head, output_tail], dim=0),
    0,
    q_full_indices,
)
return output if not return_lse else (output, lse)
Thanks. We have changed it.
    output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True
)
lse_context = lse_context.transpose(0, 1).contiguous()
if self.pcp_world_size > 1:
I wonder whether we forgot to add the computation between kv_cache_permute and prefill_query here, and then to have something like cp_lse_ag_out_rs, cp_lse_ag_out_ar, and merge_attn_states correct the output?
Thanks for your comments. The current PCP implementation on flashinfer does not yet support chunked prefill or prefix caching, so the KV context can be ignored here. Subsequent PRs will add further support.
prefill_query_across_dcp = get_dcp_group().all_gather(
    prefill_query.contiguous(), dim=1
)
output_context_tmp, lse_context_tmp = self._context.run(
Since we store the input keys and values in the cache via reshape_and_cache_flash in FlashInferImpl.forward, won't this _context computation between kv_cache_permute and prefill_query_across_dcp leak or recompute the input keys and values?
Thanks. Same reason as in the previous comment.
    workspace_buffer, get_kv_cache_layout()
)
pin_memory = is_pin_memory_available()
self.pcp_q_indptr_cpu = torch.zeros(
Just wondering — is this wrapper expected to be instantiated only once? I'm concerned about potential repeated pinned memory allocation across instances.
Yes, to my knowledge these wrappers are initialized only once; subsequent inference steps only invoke the plan and run functions.
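For context, a minimal sketch of the allocate-once pattern being discussed (the class name, buffer size, and plan() signature are hypothetical; the relevant mechanism is just that torch.zeros(..., pin_memory=True) runs once in the constructor):

```python
import torch


class PrefillWrapperBuffers:
    """Hypothetical sketch: pinned CPU staging buffers allocated once per wrapper."""

    def __init__(self, max_num_reqs: int):
        pin = torch.cuda.is_available()
        # Allocated a single time when the wrapper is constructed; later
        # plan()/run() calls only write into the buffer, so there is no
        # per-step pinned-memory allocation.
        self.pcp_q_indptr_cpu = torch.zeros(
            max_num_reqs + 1, dtype=torch.int32, pin_memory=pin
        )

    def plan(self, q_lens_cpu: torch.Tensor) -> torch.Tensor:
        # q_lens_cpu: per-request query lengths on the CPU.
        n = q_lens_cpu.numel()
        self.pcp_q_indptr_cpu[1 : n + 1] = torch.cumsum(q_lens_cpu, dim=0)
        # Return the valid prefix; a real wrapper would copy it to the GPU
        # (e.g. with a non-blocking copy) before calling the kernel.
        return self.pcp_q_indptr_cpu[: n + 1]
```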
Purpose
This PR, split from the full PR #26864, adds support for Prefill Context Parallelism (PCP) with GQA flashinfer, following PR #28718. For implementation details, please refer to RFC #25749.
TL;DR: PCP enhances long-sequence inference capabilities by partitioning the sequence dimension during the prefill stage.
The current implementation primarily includes the following changes:
- ModelRunner.py: PCP partitioning logic for tokens;
- flashinfer.py: adapting the FlashInfer backend for GQA to PCP;
- PrefillContextParallelMetadata, shared across attention backends.
Test Plan
Qwen/Qwen2.5-3B
Test Result
gsm8k eval with the following configurations:
- tp4 (17c540a)
- tp4 dcp2 interleave 8
- tp4 pcp2 interleave 8
- tp4 dcp2 pcp2 interleave 8
CC @LookAround0301 @FENP @gjc0824 @LucasWilkinson