
Conversation

@heheda12345 (Collaborator) commented Nov 7, 2025

Purpose

Some steps in model initialization only depend on layer attributes and do not depend on the KV cache config. Move them to load_model so that initialize_kv_cache can focus on KV cache and attention backend initialization.

This PR also starts to change some functions in the GPU model runner into pure functions so that they can be reused by model runner v2 in the future.
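
As a rough illustration of that pattern (everything below is a simplified stand-in, not the PR's actual code; only the helper name get_runner_only_attn_layers comes from the PR), the layer-derived state becomes a free function that depends only on the config, and load_model just assigns the result:

from dataclasses import dataclass, field


@dataclass
class FakeVllmConfig:
    # Stand-in for vllm.config.VllmConfig: just the per-layer info the
    # helper needs, keyed by layer name.
    kv_sharing_target: dict[str, str] = field(default_factory=dict)
    attn_layer_names: list[str] = field(default_factory=list)


def get_runner_only_attn_layers(vllm_config: FakeVllmConfig) -> set[str]:
    # Pure function: reads only its argument and mutates no runner state,
    # so the GPU model runner and a future model runner v2 can share it.
    # (Illustrative rule: layers that reuse another layer's KV cache.)
    return {
        name
        for name in vllm_config.attn_layer_names
        if name in vllm_config.kv_sharing_target
    }


# In load_model(), the runner simply assigns the result:
# self.runner_only_attn_layers = get_runner_only_attn_layers(self.vllm_config)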

Note that test_kv_sharing_fast_prefill fails on main and is marked as optional. This PR hits the same error in that test.

Split from #27935

Test Plan

kv sharing:
Run basic.py with:

llm = LLM(
    model="google/gemma-3n-E2B-it",
    enforce_eager=True,
    kv_sharing_fast_prefill=True,
)
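
For completeness, a self-contained approximation of that run (this assumes the stock prompts and sampling params from the basic.py example, so the script below is a sketch rather than the exact file):

from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="google/gemma-3n-E2B-it",
    enforce_eager=True,
    kv_sharing_fast_prefill=True,
)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt:    {output.prompt!r}")
    print(f"Output:    {output.outputs[0].text!r}")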

encoder only:
pytest -vs tests/entrypoints/pooling/llm/test_embedding.py::test_pooling_params

Test Result

Generated Outputs:
------------------------------------------------------------
Prompt:    'Hello, my name is'
Output:    " Alex, and I'm a freelance graphic designer. I'm passionate about"
------------------------------------------------------------
Prompt:    'The president of the United States is'
Output:    ' currently taking a trip to Europe. He is visiting several countries, including France,'
------------------------------------------------------------
Prompt:    'The capital of France is'
Output:    ' Paris.\n\nThis is a true statement.\n'
------------------------------------------------------------
Prompt:    'The future of AI is'
Output:    " a complex and fascinating topic. Here's a breakdown of key trends, potential"
------------------------------------------------------------

pytest -vs tests/entrypoints/pooling/llm/test_embedding.py::test_pooling_params

test passed


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Chen Zhang <[email protected]>
@mergify mergify bot added the v1 label Nov 7, 2025
@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request is a good step towards cleaning up the GPUModelRunner by moving initialization logic that doesn't depend on KVCacheConfig from initialize_kv_cache to load_model. The introduction of pure functions in utils.py is also a positive change for future code reuse. The refactoring appears to be correct and well-executed. I have one suggestion regarding code duplication to further improve the cleanup.

@chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


@LucasWilkinson (Collaborator) left a comment


It looks like get_attn_backend_cls is only used by _check_and_update_cudagraph_mode

        attention_backends = set(
            get_attn_backend_cls(
                self.vllm_config, self.kv_sharing_fast_prefill_eligible_layers
            ).values()
        )

Maybe we should use it in initialize_attn_backend too; i.e.

    def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the attention backends and attention metadata builders.
        """
        assert len(self.attn_groups) == 0, "Attention backends are already initialized"

        class AttentionGroupKey(NamedTuple):
            attn_backend: type[AttentionBackend]
            kv_cache_spec: KVCacheSpec

        attn_backends_dict = get_attn_backend_cls(
                self.vllm_config, self.kv_sharing_fast_prefill_eligible_layers
        )

        def get_attn_backends_for_group(
            kv_cache_group_spec: KVCacheGroupSpec,
        ) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]:
            ...
            for layer_name in kv_cache_group_spec.layer_names:
                attn_backend = attn_backends_dict[layer_name]

then we can get rid of the duplicate

if layer_name in self.kv_sharing_fast_prefill_eligible_layers:
    attn_backend = create_fast_prefill_custom_backend(
        "FastPrefill",
        attn_backend,
    )

(or we should make it just return a set if it's only used by _check_and_update_cudagraph_mode)

def get_attn_backend_cls(

nit: can we name this get_attn_backend_clss or get_attn_backend_cls_dict? Currently the name kinda implies it only returns one backend class.

@heheda12345 (Collaborator, Author) replied:

good suggestion. Updated.

@heheda12345 (Collaborator, Author) commented Nov 7, 2025

@LucasWilkinson you can check #27935 for the final version of my plan. I'll use get_attn_backend_cls in more places (including initialize_attn_backend). It will be done in a future PR.

@markmc (Member) left a comment


Moving these steps into load_model() lgtm, but I definitely can't guarantee I haven't missed something subtle!

(
    self.shared_kv_cache_layers,
    self.kv_sharing_fast_prefill_eligible_layers,
    self.kv_sharing_fast_prefill_logits_indices,

This is already initialized in the constructor; did you forget to remove it from there?

But leaving it in the constructor seems fine? That would also mean the device arg can be removed from utils.kv_sharing() so that it only returns layers, which is a nice simplification.
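
A rough sketch of what that simplification could look like (the helper name and signature below are hypothetical, not the PR's actual code): the helper takes only the config, returns the derived layer info, and the constructor keeps allocating the logits-indices buffer itself, so no device argument is needed:

def get_kv_sharing_layers(vllm_config) -> tuple[dict[str, str], set[str]]:
    # Hypothetical pure helper: returns (shared_kv_cache_layers,
    # kv_sharing_fast_prefill_eligible_layers) derived from the config;
    # no device argument and no tensor allocation inside.
    shared_kv_cache_layers: dict[str, str] = {}
    fast_prefill_eligible_layers: set[str] = set()
    # ...populate both purely from vllm_config's layer attributes...
    return shared_kv_cache_layers, fast_prefill_eligible_layers

# The runner's constructor keeps owning
# self.kv_sharing_fast_prefill_logits_indices, as it does today.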

for layer_name, attn_module in attn_layers.items():
    attn_backend = attn_module.get_attn_backend()
    if layer_name in kv_sharing_fast_prefill_eligible_layers:
        attn_backend = create_fast_prefill_custom_backend(

Hmm, this gets created again later in initialize_attn_backend()?

Can this be avoided? E.g., can it be created somewhere earlier so that get_layers_from_vllm_config() returns it?


Sorry, this is probably a duplicate of @LucasWilkinson's comment.

)
self.runner_only_attn_layers.update(
    get_runner_only_attn_layers(self.vllm_config)
)

Can we get rid of the initialization in the constructor and the assertion here, and just do:

self.runner_only_attn_layers = get_runner_only_attn_layers(self.vllm_config)

check_ubatch_thresholds,
)
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.utils import (

Just a thought ... what belongs in utils, and what belongs in the model runner? Are you putting these in utils so they can be re-used by other model runners? Is that a good refactoring goal in general - move as much code as possible into utils?


Oh, I see

This PR also starts to change some functions in gpu model runner into pure functions so that they can be reused by model runner v2 in the future.

I guess my preference would be to keep the purpose of the PR clean: move some steps into load_model() in this PR, and do the more complete "prepare for model runner v2" refactoring in a separate PR. It's hard to judge whether these functions are a positive refactoring move in the context of this PR.

Not a strong objection though 🤷

mergify bot commented Nov 13, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @heheda12345.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 13, 2025
