
Conversation

Contributor

@pisceskkk pisceskkk commented Nov 14, 2025

Purpose

This PR, split out from the full PR #26864, adds support for Prefill Context Parallelism (PCP) with GQA on the FlashInfer backend, following PR #28718. For implementation details, please refer to RFC #25749.
TL;DR: PCP enhances long-sequence inference capabilities by partitioning the sequence dimension during the prefill stage.

The current implementation primarily includes the following changes:

  • Modified ModelRunner.py to add the PCP token-partitioning logic (see the sketch below);
  • Modified flashinfer.py to adapt the FlashInfer GQA backend to PCP;
  • Added PrefillContextParallelMetadata, shared across attention backends;
  • Renamed variables and functions shared by both PCP and DCP.
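
As a rough illustration of the token-partitioning idea, here is a minimal sketch inferred from the head/tail naming used in the backend changes; the function name and the exact scheme are illustrative, not the code added in ModelRunner.py:

import torch

def split_prefill_tokens_head_tail(token_ids: torch.Tensor, pcp_rank: int, pcp_world_size: int) -> torch.Tensor:
    # Split one prefill sequence into 2 * world_size chunks and give rank r
    # chunk r (the "head") plus chunk 2 * world_size - 1 - r (the "tail"),
    # so every rank gets a balanced share of the causal-attention work.
    # Illustrative only; the real partitioning also handles batching and padding.
    chunks = token_ids.chunk(2 * pcp_world_size)
    head = chunks[pcp_rank]
    tail = chunks[2 * pcp_world_size - 1 - pcp_rank]
    return torch.cat([head, tail])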

Test Plan

Qwen/Qwen2.5-3B

export VLLM_ATTENTION_BACKEND='FLASHINFER'
vllm serve Qwen/Qwen2.5-3B --tensor-parallel-size 4 --decode-context-parallel-size 2 --prefill-context-parallel-size 2 --dcp-kv-cache-interleave-size 8

Test Result

gsm8k eval

All runs evaluate gsm8kdataset with metric avg@5 (mode: gen) via vllm-api-general-stream.

configuration                 avg@5
tp4 (17c540a)                 72.78
tp4 dcp2 interleave 8         72.43
tp4 pcp2 interleave 8         72.51
tp4 dcp2 pcp2 interleave 8    72.98

CC @LookAround0301 @FENP @gjc0824 @LucasWilkinson


mergify bot commented Nov 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @pisceskkk.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for Prefill Context Parallelism (PCP) for GQA with flashinfer, which is a significant feature for enhancing long-sequence inference. The changes are extensive, touching configuration, parallel state management, attention backends, and the model runner. Overall, the implementation looks solid, but I've identified a few critical issues that need to be addressed. These include a duplicated command-line argument, a syntax error, a typo in a variable name, and incorrect tensor indexing, all of which could lead to runtime errors or prevent the code from running.

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

@pisceskkk pisceskkk force-pushed the pcp+flashinfer branch 2 times, most recently from c0f45f9 to 489b6c5 Compare November 18, 2025 09:12
@mergify mergify bot removed the needs-rebase label Nov 18, 2025
@pisceskkk pisceskkk force-pushed the pcp+flashinfer branch 2 times, most recently from 8bc261d to 58cbd8f Compare November 18, 2025 09:54

mergify bot commented Nov 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @pisceskkk.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 19, 2025
pisceskkk and others added 2 commits November 20, 2025 09:04
Co-authored-by: QiuChunshuo <[email protected]>
Co-authored-by: FENP <[email protected]>
Co-authored-by: LookAround <[email protected]>
Co-authored-by: Jingchun Gao <[email protected]>
Co-authored-by: zhenwenqi2024 <[email protected]>
Signed-off-by: QiuChunshuo <[email protected]>
Signed-off-by: FENP <[email protected]>
Signed-off-by: LookAround <[email protected]>
Signed-off-by: Jingchun Gao <[email protected]>
Signed-off-by: zhenwenqi2024 <[email protected]>
Jingchun Gao added 2 commits November 23, 2025 23:04
Signed-off-by: Jingchun Gao <[email protected]>
Signed-off-by: Jingchun Gao <[email protected]>
@gjc0824 gjc0824 force-pushed the pcp+flashinfer branch 2 times, most recently from b9ed205 to d6bbe6d Compare November 24, 2025 01:59
Signed-off-by: Jingchun Gao <[email protected]>
Collaborator

@LucasWilkinson LucasWilkinson left a comment

Left some comments on #28988 which I think similarly apply here

self.is_mm_embed = self._make_buffer(max_num_tokens, dtype=torch.bool)

# Persistent buffers for Prefill Context Parallelism
if self.pcp_world_size > 1:
Collaborator

can we please separate all of this into a PCPManager or a utils file to make it more modular and easier to migrate to model runner v2?

Contributor

OK. We will make similar changes as in #28988.

Collaborator

@LucasWilkinson LucasWilkinson left a comment

Thanks for the contribution! A few more comments

dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
cp_local_seq_lens=common_attn_metadata.cp_local_seq_lens,
cp_local_seq_lens_cpu=common_attn_metadata.cp_local_seq_lens_cpu,
pcp_metadata=common_attn_metadata.pcp_metadata,
Collaborator

if PCP + spec-decode is not yet supported should we be passing this blindly?

Contributor

Sure. We have removed these params to avoid misunderstanding.

]
)
else:
wrappers_to_check.append((prefill_wrapper._new_tokens, True))
Collaborator

this is kinda messy imo; can you just do something like:

class BatchCPPrefillWrapper:
    @property
    def _window_left(self):
        assert self._context._window_left == self._new_tokens._window_left
        return self._context._window_left
    ...

Let's also do:

class FlashInferImpl:
    def __init__(self, ...):
        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

and then only run the asserts if self.is_debugging_mode to avoid excessive CPU overhead on each forward pass.
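
As a rough sketch of how that guard might look in the forward path (illustrative, reusing the attribute names from the snippet above; not the code in this PR):

# Only run cross-wrapper consistency checks in debug mode so the hot path
# pays no extra CPU cost per forward pass.
if self.is_debugging_mode:
    assert prefill_wrapper._context._window_left == prefill_wrapper._new_tokens._window_left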

Contributor

Thanks for your comments. We have changed it.


num_actual_tokens = attn_metadata.num_actual_tokens

if self.pcp_world_size > 1:
Collaborator

can this be moved into BatchCPPrefillWrapper.run(...)?

Contributor

@gjc0824 gjc0824 Nov 25, 2025

No. All-gathering and restoring the KV is required in both the prefill and decode phases, so it cannot be moved into the PrefillWrapper. But we have extracted it into vllm/v1/attention/backends/utils.py as a general function that other backends can reuse.
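
For reference, such a helper could look roughly like the sketch below; the function name, the restore_indices argument, and passing the group explicitly are assumptions for illustration, not the actual signature in utils.py:

import torch

def cp_all_gather_restore_kv(key: torch.Tensor, value: torch.Tensor, cp_group, restore_indices: torch.Tensor):
    # cp_group is assumed to expose .all_gather(tensor, dim=...) like the DCP
    # group used elsewhere in this backend (e.g. get_dcp_group()).
    key_full = cp_group.all_gather(key.contiguous(), dim=0)
    value_full = cp_group.all_gather(value.contiguous(), dim=0)
    # restore_indices maps gathered positions back to the original token order;
    # in the real helper this would come from the PCP metadata.
    return key_full[restore_indices], value_full[restore_indices]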

out,
lse,
get_dcp_group(),
return_lse=True,
Collaborator

should this be return_lse=self.pcp_world_size > 1

Contributor

Sure. We have changed it to avoid unnecessary computation.

kv_for_head_indices: torch.Tensor | None = None
kv_for_tail_indices: torch.Tensor | None = None
kv_for_head_indptr: torch.Tensor | None = None
kv_for_tail_indptr: torch.Tensor | None = None
Collaborator

this seems FlashInfer specific? we shouldn't have backend specific things (or styles) in CommonAttentionMetadata

Contributor

Sure. The *_indptr params are FlashInfer specific, so we removed them from the general PrefillContextParallelMetadata and moved the computation of PrefillContextParallelMetadata to the attention-wrapper build stage.
When building the attention wrapper, we compute the *_indptr params in flashinfer.py and then use general functions such as get_q_indices to construct the PrefillContextParallelMetadata. We think this refactoring not only keeps backend-specific parameters where they belong but also reduces the dependency on the model runner.
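
As a rough illustration of the indptr-to-indices relationship used here (names are illustrative; the real get_q_indices and metadata construction may differ):

import torch

def indptr_to_indices(indptr: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    # indptr is CSR-style ([0, n0, n0 + n1, ...]) delimiting each request's
    # slots; offsets[i] is where request i's tokens start in the flattened
    # buffer. Returns one flat index per slot.
    lengths = indptr[1:] - indptr[:-1]
    starts = torch.repeat_interleave(offsets, lengths)
    local = torch.arange(int(indptr[-1])) - torch.repeat_interleave(indptr[:-1], lengths)
    return starts + local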

k_head = torch.index_select(key, 0, kv_for_head_indices)
v_head = torch.index_select(value, 0, kv_for_head_indices)
k_tail = torch.index_select(key, 0, kv_for_tail_indices)
v_tail = torch.index_select(value, 0, kv_for_tail_indices)
Collaborator

so many index_select seems very expensive; have you profiled this?

Contributor

Yes, we profiled it. Because PCP is only meaningful for long input sequences, we ran performance tests with an input sequence length of 32k. The cost of the index_select calls is negligible compared to the benefits brought by PCP.

**common_kwargs,
)

def _attention_with_head_and_tail(
Collaborator

can this be moved into the BatchCPPrefillWrapper?

Contributor

Sure. We have moved it into BatchCPPrefillWrapper.run().

Comment on lines 310 to 330
if return_lse:
    output_head, lse_head = output_head
    output_tail, lse_tail = output_tail
    output = torch.index_select(
        torch.cat([output_head, output_tail], dim=0),
        0,
        q_full_indices,
    )
    lse = torch.index_select(
        torch.cat([lse_head, lse_tail], dim=0),
        0,
        q_full_indices,
    )
    return output, lse
else:
    output = torch.index_select(
        torch.cat([output_head, output_tail], dim=0),
        0,
        q_full_indices,
    )
    return output
Collaborator

Suggested change (replacing the block quoted above):

if return_lse:
    output_head, lse_head = output_head
    output_tail, lse_tail = output_tail
    lse = torch.index_select(
        torch.cat([lse_head, lse_tail], dim=0),
        0,
        q_full_indices,
    )
output = torch.index_select(
    torch.cat([output_head, output_tail], dim=0),
    0,
    q_full_indices,
)
return output if not return_lse else (output, lse)

Contributor

Thanks. We have changed it.

output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True
)
lse_context = lse_context.transpose(0, 1).contiguous()
if self.pcp_world_size > 1:
Contributor

I wonder whether we forgot to add the computation between kv_cache_permute and prefill_query here? Then something like cp_lse_ag_out_rs, cp_lse_ag_out_ar and merge_attn_states would correct the output?

Contributor

Thanks for your comments. The current PCP implementation on FlashInfer does not yet support chunked prefill or prefix caching, so the KV context can be ignored here. Subsequent PRs will add further support.


mergify bot commented Nov 25, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @pisceskkk.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 25, 2025
prefill_query_across_dcp = get_dcp_group().all_gather(
prefill_query.contiguous(), dim=1
)
output_context_tmp, lse_context_tmp = self._context.run(
Contributor

@Livinfly Livinfly Nov 25, 2025

As we store the input keys and values in the cache via reshape_and_cache_flash in FlashInferImpl.forward, won't this _context computation between kv_cache_permute and prefill_query_across_dcp leak / recompute the input keys and values?

Contributor

Thanks. Same reasoning as in the previous comment.

Jingchun Gao added 2 commits November 25, 2025 21:07
Signed-off-by: Jingchun Gao <[email protected]>
workspace_buffer, get_kv_cache_layout()
)
pin_memory = is_pin_memory_available()
self.pcp_q_indptr_cpu = torch.zeros(
Contributor

Just wondering — is this wrapper expected to be instantiated only once? I'm concerned about potential repeated pinned memory allocation across instances.

Contributor Author

Yes, to my knowledge, these wrappers are initialized only once. Subsequent inference steps only invoke their plan and run functions.
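
For context, the allocate-once pattern under discussion looks roughly like this (illustrative sketch, not the exact wrapper code):

import torch

class WrapperSketch:
    def __init__(self, max_num_reqs: int, pin_memory: bool):
        # Pinned CPU buffer allocated a single time at construction.
        self.pcp_q_indptr_cpu = torch.zeros(
            max_num_reqs + 1, dtype=torch.int32, pin_memory=pin_memory
        )

    def plan(self, q_lens: torch.Tensor) -> None:
        # Each step only fills the persistent buffer in place; no new
        # pinned allocation happens here.
        n = q_lens.numel()
        self.pcp_q_indptr_cpu[1 : n + 1] = torch.cumsum(q_lens, dim=0)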
