
Conversation

@dsikka (Contributor) commented Sep 5, 2024

Summary

  • Add GPTQ Marlin MoE support; the Marlin MoE kernels currently support int4
  • Update/add optional testing for large MoE models for GPTQ and llm-compressor

Co-authored by @ElizaWszola from Neural Magic
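
A minimal usage sketch (the checkpoint is the GPTQ Mixtral model benchmarked below; the prompt and sampling settings are illustrative):

from vllm import LLM, SamplingParams

# Load a 4-bit GPTQ MoE checkpoint; with this PR, the expert layers run
# through the Marlin MoE kernels instead of the slower pre-existing path.
llm = LLM(model="TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ", tensor_parallel_size=2)
params = SamplingParams(temperature=0.0, max_tokens=64)
out = llm.generate(["What is a mixture-of-experts model?"], params)
print(out[0].outputs[0].text)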

@github-actions bot commented Sep 5, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which exercises a small and essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can do one of the following:

  • Add the ready label to the PR
  • Enable auto-merge

🚀

@dsikka dsikka marked this pull request as ready for review September 6, 2024 15:17
@robertgshaw2-redhat robertgshaw2-redhat added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 6, 2024
@dsikka dsikka requested a review from mgoin September 9, 2024 17:38
@mgoin (Member) left a comment

I need to run a quick test on it myself, but this looks good to land for 4-bit support!

@mgoin (Member) commented Sep 9, 2024

Performance looks great!

python benchmarks/benchmark_latency.py --model mistralai/Mixtral-8x7B-Instruct-v0.1 --tensor-parallel-size 2 --input-len 128 --output-len 512 --batch-size 1 --num-iters-warmup 2 --num-iters 10
Avg latency: 5.879553547129035 seconds

python benchmarks/benchmark_latency.py --model nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized --tensor-parallel-size 2 --input-len 128 --output-len 512 --batch-size 1 --num-iters-warmup 2 --num-iters 10
Avg latency: 4.726654114946723 seconds

python benchmarks/benchmark_latency.py --model TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ --tensor-parallel-size 2 --input-len 128 --output-len 512 --batch-size 1 --num-iters-warmup 2 --num-iters 10
Avg latency: 4.787863119132817 seconds

Before this PR, GPTQ Mixtral would be much slower:

python benchmarks/benchmark_latency.py --model TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ --tensor-parallel-size 2 --input-len 128 --output-len 512 --batch-size 1 --num-iters-warmup 2 --num-iters 10
Avg latency: 8.206900223530829 seconds
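
From the numbers above, a quick sanity check of the improvement (latencies copied from the runs reported in this comment):

before = 8.206900223530829  # avg latency (s), GPTQ Mixtral before this PR
after = 4.787863119132817   # avg latency (s), GPTQ Mixtral with Marlin MoE
print(f"speedup: {before / after:.2f}x")  # ~1.71x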

@mgoin mgoin merged commit 6cd5e5b into vllm-project:main Sep 10, 2024
@fengyang95 commented

Does this support deepseek-v2?

@xiaoqi35 commented

Thanks! That's an important feature for deepseek-v2: quantized deepseek-v2 models need a fused MoE that supports int4 quantization. Is AWQ quantization also supported?

dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request Sep 12, 2024
@ElizaWszola (Contributor) commented

Sonnet benchmark results (no act order, 4-bit):

# 4-bit quantized MoE without act order
llm = LLM(model="TheBloke/Mixtral-8x7B-v0.1-GPTQ")
[Benchmark plots: vLLM TTFT and TPOT for Mixtral 8x7B, 4-bit, three runs each]
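
For context, a sketch of how TTFT/TPOT numbers like these are typically gathered with vLLM's serving benchmark against the sonnet dataset (exact flags are assumptions and may vary by version):

# terminal 1: start an OpenAI-compatible server with the quantized model
python -m vllm.entrypoints.openai.api_server --model TheBloke/Mixtral-8x7B-v0.1-GPTQ --tensor-parallel-size 2

# terminal 2: run the serving benchmark with the sonnet dataset
python benchmarks/benchmark_serving.py --backend vllm --model TheBloke/Mixtral-8x7B-v0.1-GPTQ --dataset-name sonnet --dataset-path benchmarks/sonnet.txt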

@ElizaWszola (Contributor) commented Sep 18, 2024

Thanks! That's an important feature for deepseek-v2: quantized deepseek-v2 models need a fused MoE that supports int4 quantization. Is AWQ quantization also supported?

AWQ is currently in our future work scope!

Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
garg-amit pushed a commit to garg-amit/vllm that referenced this pull request Oct 28, 2024
@mgoin mgoin mentioned this pull request Feb 17, 2025
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025