
Conversation

@youkaichao (Member) commented Jun 3, 2025

#17211 explores the possibility of compiling multiple models in the same process, i.e. both the main model and the eagle head model. However, it does this by extending the compilation cache directory in a tricky way. In addition, that integration can be problematic when we compile for specific shapes: the env vars are set multiple times, with the later value overriding the previous one:

os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache

This PR reorganizes the cache directory structure so that the same vLLM instance uses the same TORCHINDUCTOR_CACHE_DIR and TRITON_CACHE_DIR, with separate storage only for vllm_compile_cache.py etc.

It also reads the prefix automatically, which I think will be helpful for future vision encoder compilation.

The resulting structure after running examples/offline_inference/eagle.py:

~/.cache/vllm/torch_compile_cache/762970b379/rank_0_0
  - inductor_cache
  - triton_cache
  - backbone
    - computation_graph.py
    - transformed_code.py
    - vllm_compile_cache.py
  - eagle_head
    - computation_graph.py
    - transformed_code.py
    - vllm_compile_cache.py
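
As a rough sketch of how this layout can be consumed per model component (the helper below is purely illustrative, not code from this PR):

import os

def component_cache_dirs(base_cache_dir: str, model_tag: str) -> dict:
    # Illustrative helper: inductor/triton caches are shared at the top
    # level, while each model component ("backbone", "eagle_head", ...)
    # gets its own subdirectory for vllm_compile_cache.py and friends.
    return {
        "TORCHINDUCTOR_CACHE_DIR": os.path.join(base_cache_dir, "inductor_cache"),
        "TRITON_CACHE_DIR": os.path.join(base_cache_dir, "triton_cache"),
        "component_dir": os.path.join(base_cache_dir, model_tag),
    }

Both components resolve the same inductor and triton cache paths and differ only in component_dir.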

@youkaichao (Member Author) commented:

cc @zou3519 @houseroad

@mergify bot added the v1 label Jun 3, 2025
github-actions bot commented Jun 3, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@houseroad (Collaborator) left a comment


Overall, the idea is pretty neat. Left two more comments.

Also, please ensure the cache can still be loaded appropriately. :-)

def initialize_cache(self,
                     cache_dir: str,
                     disable_cache: bool = False,
                     prefix: str = ""):

Since prefix is only used to calculate the base_cache_dir, why not pass in the base_cache_dir instead of passing in the prefix?
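
For context, a minimal sketch of the relationship this comment is about, assuming cache_dir has the form <base_cache_dir>/<prefix> (illustrative only, not the merged code):

import os

def base_cache_dir_from(cache_dir: str, prefix: str = "") -> str:
    # Assumption: when a prefix is given, cache_dir ends with it, so the
    # shared base directory is simply the parent directory.
    return os.path.dirname(cache_dir) if prefix else cache_dir

# The alternative suggested above would instead pass base_cache_dir in
# directly and skip this derivation.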

def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
def set_current_vllm_config(vllm_config: VllmConfig,
                            check_compile=False,
                            prefix: Optional[str] = None):

Add a brief comment to explain what prefix means?
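
One possible wording for such a comment (a suggestion only, with the semantics inferred from this PR's description):

def set_current_vllm_config(vllm_config, check_compile=False, prefix=None):
    # prefix: name of the model component being compiled (e.g. "backbone"
    # or "eagle_head"); it selects the per-component subdirectory inside
    # the shared torch_compile_cache directory.
    ...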

self.model = get_model(vllm_config=self.vllm_config,
                       model_config=draft_model_config)
from vllm.compilation.backends import set_model_tag
with set_model_tag("eagle_head"):
@zou3519 (Collaborator) commented Jun 3, 2025

nit: could we name this something like "set_compile_region" or "set_model_component" (see the other comment)? That would make it clearer that this is 1:1 with a fullgraph torch.compile region
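
For reference, a minimal sketch of what a tag-setting context manager like set_model_tag could look like (illustrative only; the actual implementation lives in vllm.compilation.backends):

import contextlib

_model_tag = "backbone"  # assumed default tag for the main model

@contextlib.contextmanager
def set_model_tag(tag: str):
    # Record which model component is being compiled so the compilation
    # backend can route its artifacts to the matching cache subdirectory.
    global _model_tag
    previous, _model_tag = _model_tag, tag
    try:
        yield
    finally:
        _model_tag = previous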

def initialize_cache(self,
                     cache_dir: str,
                     disable_cache: bool = False,
                     prefix: str = ""):

nit: This is technically a subdirectory (or a suffix to the path), not a prefix. I was prototyping something like this locally and I called this the "model_component", but up to you

@zou3519 (Collaborator) left a comment

thank you!

@luyuzhe111 (Contributor) commented Jun 3, 2025

Hi @youkaichao, I wonder if we can simplify the configure_post_pass(self) method here? I had to make some edits to make things work here, but maybe they are no longer necessary? Thanks!

@vadiklyutiy (Collaborator) commented:

Why did you decide to keep inductor_cache and triton_cache at the upper level, rather than inside backbone and eagle_head? The current organization introduces an additional base_cache_dir and makes the code a little more complicated.

@vadiklyutiy (Collaborator) commented:

Could we pass model_tag as an argument to support_torch_compile? That might be a little better than using a global.
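
A sketch of what that alternative could look like (hypothetical decorator, not the existing support_torch_compile signature):

from typing import Optional, Type

def support_torch_compile_with_tag(*, model_tag: Optional[str] = None):
    # Hypothetical variant: the tag travels with the decorator rather than
    # being set through a global context manager.
    def wrapper(cls: Type) -> Type:
        cls._compile_model_tag = model_tag
        return cls
    return wrapper

@support_torch_compile_with_tag(model_tag="eagle_head")
class EagleHeadStub:
    pass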

@houseroad added the ready label (ONLY add when PR is ready to merge/full CI is needed) Jun 12, 2025
@houseroad merged commit d70bc7c into vllm-project:main Jun 13, 2025
73 checks passed