-
-
Notifications
You must be signed in to change notification settings - Fork 11.7k
Integrate fused Mixtral MoE with Marlin kernels #7079
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Integrate fused Mixtral MoE with Marlin kernels #7079
Conversation
|
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge). To run full CI, you can do one of these:
🚀 |
|
/unready |
Refactoring for maintainability
|
@dsikka I've added some |
| expert_id: int, | ||
| is_gptq: bool = False, | ||
| ): | ||
| if is_gptq: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'd want to use the weight loading functionality already present.
| "MistralForCausalLM": ("llama", "LlamaForCausalLM"), | ||
| "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), | ||
| "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), | ||
| "QuantMixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we'd want mixtral_quant by default
| gate_down_up = [ | ||
| ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name | ||
| ] | ||
| return ([ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we leverage what already exists?
This PR's functionality has been implemented through PRs #8217 and #8032. I'm closing it.
Reimplement quantized Mixtral to combine Marlin kernels with fused MoE.
This PR rewrites the Mixtral model to run a modified Marlin kernel that takes advantage of
fused_moefunctionality.The C++ code takes in all expert data and
topk_idstensor. It runs a kernel to computesorted_idsoffsets related to each expert, and then feeds them to the Marlin kernels. The Marlin kernels are run multiple times per each expert, using current expert number to figure out the current position insidesorted_idsand the number of tokens to process in each particular call. The values ofsorted_idsare then used to indirectly access the rows of input/outputA/Ctensors. If the the rows of inputAare identical for each oftopkexperts that access them (first MMM of fused MoE), tensorAconsists ofM x Kelements, with each row being accessedtopktimes by the relevant experts. Otherwise (second MMM of fused MoE),Aconsists ofM x topk x Kelements, with each row being accessed once.Unit testing:
End-to-end testing:
Run
offline_inference.pywithSonnet benchmark results (no act order, 4-bit):
Sonnet benchmark results (with act order, 8-bit):