Skip to content

Conversation

@bnellnm
Copy link
Collaborator

@bnellnm bnellnm commented Jul 29, 2024

Miscellaneous changes to support torch.compile in vLLM.

  • Add meta functions for various ops to prevent torch.compile graph breaks.
  • Use string schemas for all dispatched op registrations. (the function pointer API is only used for ops that are registered for all keys or that do not take Tensor arguments)
  • Fix some type mismatches in the quantization code
  • In the aqlm kernel/code, change codebook_partition_sizes into a list instead of a Tensor since it is allocated on the CPU.
  • Bump up dynamo cache limits due to the amount of recompilation in cuda graph warmup.
  • Add torch.library.opcheck tests for ops that had "unit" tests and are opcheck-able.

Note: opcheck does not seem to work with torch.float8_e4m3fn. It complains that mul_cuda is not supported for that type.


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@bnellnm bnellnm changed the title Add meta functions for ops to prevent graph breaks [Kernel][Misc] Add meta functions for ops to prevent graph breaks Jul 29, 2024
@bnellnm bnellnm force-pushed the fix-graph-breaks branch 2 times, most recently from 3713ec3 to 63c42c7 Compare August 5, 2024 22:03
@bnellnm bnellnm marked this pull request as ready for review August 6, 2024 03:09
@bnellnm
Copy link
Collaborator Author

bnellnm commented Aug 6, 2024

/ready
/torch.compile

@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 6, 2024
@bnellnm bnellnm force-pushed the fix-graph-breaks branch 2 times, most recently from 97a11f1 to f015312 Compare August 6, 2024 20:15
@youkaichao youkaichao self-assigned this Aug 6, 2024
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain the rationale here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a different set of failures when fullgraph=True. Not just graph breaks but actual dynamo errors. Locally, I am switching between full/not-full constantly to test things.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you send me the specific case where fullgraph fails?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's probably more but these all fail with fullgraph=True and pass with fullgraph=False

FAILED tests/models/test_gguf.py::test_models[1-5-32-half-model0] - torch._dynamo.exc.Unsupported: hasattr TensorVariable shard_size
FAILED tests/models/test_gguf.py::test_models[1-5-32-half-model1] - torch._dynamo.exc.InternalTorchDynamoError: 'SymNodeVariable' object has no attribute 'value'
FAILED tests/models/test_gguf.py::test_models[1-5-32-half-model2] - torch._dynamo.exc.Unsupported: hasattr TensorVariable shard_size
FAILED tests/models/test_gguf.py::test_models[1-5-32-half-model3] - torch._dynamo.exc.Unsupported: hasattr TensorVariable shard_size

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe torch.library.define is generally not necessary. We should use high-level APIs listed in https://pytorch.org/docs/main/library.html .

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried using custom_op here but it couldn't deal with the window_size tuple so I ended up sticking with define

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the longer term answer is to move the registration into the flash attn module itself once it is pulled into the main vllm build process.

Comment on lines +72 to +80
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm ok with this for now, but after I port the technique to skip dynamo guard evaluation overhead, this should not be necessary anymore.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you need this?

Copy link
Collaborator Author

@bnellnm bnellnm Sep 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pytorch had a problem with the set() here. It was one of the things fixed in the nightly version but not in 2.4. Using set also didn't seem to be particularly useful either so I left it as a range.

Copy link
Member

@youkaichao youkaichao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the pr and sorry for keeping you waiting for so long!

the registration for the quantization kernels makes sense to me. but I don't understand the rest changes.

Copy link
Member

@youkaichao youkaichao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the great pr!

we can further discuss if we need to fix the graph breaks ourselves or pytorch team will fix them.

@youkaichao youkaichao merged commit 73202db into vllm-project:main Sep 11, 2024
@youkaichao youkaichao deleted the fix-graph-breaks branch September 11, 2024 19:52
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request Sep 12, 2024
try:
torch.ops._C.gptq_marlin_24_gemm # noqa B018

@torch.library.register_fake("_C::gptq_marlin_24_gemm")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This again breaks compatibility with torch < 2.4

Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
garg-amit pushed a commit to garg-amit/vllm that referenced this pull request Oct 28, 2024
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed torch.compile

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants