Expert Parallelism (EP) Support for DeepSeek Models #12583
Merged
Commits (changes shown from 22 of 43 commits)
949c914  MoE init (cakeng)
ba0549e  EP config integrations (cakeng)
7d7285d  Weight loading (cakeng)
991d126  Added EP info to moe_align_block_size (cakeng)
df154b6  Bugs (cakeng)
3d700c4  Working EP+TP Prototype (cakeng)
2d05161  Removed debugging print statements (cakeng)
60284b5  Fixes (cakeng)
9cdb728  Merge branch 'main' into moe (cakeng)
03b8afb  Fused MoE Kernel fixes, Errors on CudaGraph Capture (cakeng)
cdb252d  Merge branch 'vllm-project:main' into moe (cakeng)
f7dcd7b  Expert mapping fixes, num_experts does not need to be divisible by EP… (cakeng)
b3e00f5  Merge branch 'main' into moe (cakeng)
d8cb2b3  Merge branch 'main' into moe (cakeng)
837a0fb  Merge branch 'main' into moe (cakeng)
cefcef6  Integrating to DeepSeekV3 (cakeng)
e550752  EP correctness on DeepSeekV3 checked (cakeng)
ea5fbdc  Moved expert_parallel_size argument to expertimental_expert_parallel_… (cakeng)
8485a9a  Added environment variable VLLM_TEST_ENABLE_EP to turn on EP, default… (cakeng)
4c7ba48  Merge branch 'main' into moe (cakeng)
5cadaa0  Errors running test scripts (cakeng)
83a7190  Fixed FusedMoE apply base method (cakeng)
7102ae7  Added test_expert_parallel.py and modified test_moe.py to include EP … (cakeng)
1951aa7  Pre-commit (cakeng)
b9c8c18  Merge branch 'main' into moe (cakeng)
1254ba6  Clean debug print statements (cakeng)
c8a6c64  Fixed FusedMoE apply function signatures (cakeng)
b721ee7  Fixed remaining FusedMoE apply function signatures (cakeng)
2a9993b  Removed debugging print statement. (cakeng)
56f9828  Update parallel_state.py (cakeng)
5e53fbb  Update parallel_state.py (cakeng)
182423a  Reduce diffs (cakeng)
5ced495  Updated comments from Lucas and Tyler (cakeng)
9bf31be  Made TP rank setting more explicit during the FusedMoE weight loading (cakeng)
88ce3a9  Merge branch 'main' into moe (tlrmchlsmth)
99e2d98  test_moe.py fixes (cakeng)
ee3f981  Merge branch 'main' into moe (cakeng)
5df1e06  Merge branch 'main' into moe (cakeng)
c5a21c6  fix test_expert_parallel.py env_variable (cakeng)
accb533  fix test_expert_parallel.py test configurations (cakeng)
14864f7  Merge branch 'main' into moe (cakeng)
a0f2c21  Fix pre-commit (cakeng)
5554d35  Revert FusedMoE num_experts variable name change and add deepseekV2 t… (cakeng)
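Several of these commits concern the expert-to-rank mapping (including the case where num_experts is not evenly divisible by the EP size), the EP metadata passed to moe_align_block_size, and the VLLM_TEST_ENABLE_EP opt-in toggle. The snippet below is only a minimal, illustrative sketch of one way such a mapping can be built; the function and variable names (build_expert_map, ep_size, ep_rank) are hypothetical and this is not the PR's actual implementation.

# Illustrative sketch only: how a rank might derive its local expert slice
# when the global expert count is not evenly divisible by the EP world size.
from typing import List, Optional


def build_expert_map(num_experts: int, ep_size: int,
                     ep_rank: int) -> List[Optional[int]]:
    """Map each global expert id to a local id on this rank, or None."""
    # Earlier ranks absorb the remainder so every expert has exactly one owner.
    base = num_experts // ep_size
    remainder = num_experts % ep_size
    num_local = base + (1 if ep_rank < remainder else 0)
    start = ep_rank * base + min(ep_rank, remainder)

    expert_map: List[Optional[int]] = [None] * num_experts
    for local_id in range(num_local):
        expert_map[start + local_id] = local_id
    return expert_map


if __name__ == "__main__":
    # 10 experts over 4 EP ranks -> ranks own 3, 3, 2, 2 experts respectively.
    for rank in range(4):
        owned = [g for g, l in enumerate(build_expert_map(10, 4, rank))
                 if l is not None]
        print(f"rank {rank} owns experts {owned}")

Under a scheme like this, each rank loads weights only for the experts it owns, and tokens routed to experts owned by another rank are skipped locally and handled by that rank.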
New file: test_expert_parallel.py (added in commit 7102ae7)

@@ -0,0 +1,229 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import List, Literal, NamedTuple, Optional

import pytest

from vllm.config import TaskOption
from vllm.logger import init_logger

from ..utils import compare_two_settings, fork_new_process_for_each_test

logger = init_logger("test_pipeline_parallel")


class ParallelSetup(NamedTuple):
    tp_size: int
    eager_mode: bool
    chunked_prefill: bool


class EPTestOptions(NamedTuple):
    trust_remote_code: bool
    tokenizer_mode: Optional[str]
    load_format: Optional[str] = None
    hf_overrides: Optional[str] = None


@dataclass
class EPTestSettings:
    parallel_setups: List[ParallelSetup]
    distributed_backends: List[str]
    task: TaskOption
    test_options: EPTestOptions

    @staticmethod
    def detailed(
        *,
        tp_base: int = 2,
        task: TaskOption = "auto",
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
        load_format: Optional[str] = None,
        hf_overrides: Optional[str] = None,
    ):
        return EPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              eager_mode=False,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=tp_base,
                              eager_mode=True,
                              chunked_prefill=False),
                ParallelSetup(tp_size=2 * tp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=2 * tp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp", "ray"],
            task=task,
            test_options=EPTestOptions(trust_remote_code=trust_remote_code,
                                       tokenizer_mode=tokenizer_mode,
                                       load_format=load_format,
                                       hf_overrides=hf_overrides),
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 2,
        task: TaskOption = "auto",
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
        load_format: Optional[str] = None,
        hf_overrides: Optional[str] = None,
    ):
        return EPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              eager_mode=True,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              eager_mode=False,
                              chunked_prefill=True),
            ],
            distributed_backends=["ray"],
            task=task,
            test_options=EPTestOptions(trust_remote_code=trust_remote_code,
                                       tokenizer_mode=tokenizer_mode,
                                       load_format=load_format,
                                       hf_overrides=hf_overrides),
        )

    def iter_params(self, model_name: str):
        opts = self.test_options

        for parallel_setup in self.parallel_setups:
            for distributed_backend in self.distributed_backends:
                yield (model_name, parallel_setup, distributed_backend,
                       self.task, opts)


# NOTE: You can adjust tp_base locally to fit the model in GPU
# The values displayed here are only a rough indicator of the size of the model

# yapf: disable
TEST_MODELS = {
    # "ai21labs/Jamba-v0.1": EPTestSettings.fast(trust_remote_code=True),
    # "deepseek-ai/deepseek-llm-7b-chat": EPTestSettings.fast(
    #     trust_remote_code=True),
    # "deepseek-ai/DeepSeek-V2-Lite-Chat": EPTestSettings.fast(
    #     trust_remote_code=True),
    "mistralai/Mixtral-8x7B-Instruct-v0.1": EPTestSettings.fast(tp_base=4),
}


def _compare_tp(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    test_options: EPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate", "encode"],
):
    (
        tp_size,
        eager_mode,
        chunked_prefill,
    ) = parallel_setup
    (
        trust_remote_code,
        tokenizer_mode,
        load_format,
        hf_overrides,
    ) = test_options

    if num_gpus_available < tp_size:
        pytest.skip(f"Need at least {tp_size} GPUs")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if task != "auto":
        common_args.extend(["--task", task])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", hf_overrides])

    ep_env = {
        "VLLM_TEST_EP_PARALLEL": "0",
    }

    ep_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
    ]

    # compare without pipeline parallelism
    # NOTE: use mp backend for TP
    # PP tests might involve multiple nodes, and ray might
    # schedule all workers in a node other than the head node,
    # which can cause the test to fail.
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    try:
        compare_two_settings(model_name,
                             ep_args,
                             tp_args,
                             ep_env,
                             method=method)
    except Exception:
        raise


@pytest.mark.parametrize(
    ("model_name", "parallel_setup", "distributed_backend", "task",
     "test_options"),
    [
        params for model_name, settings in TEST_MODELS.items()
        for params in settings.iter_params(model_name)
    ],
)
@fork_new_process_for_each_test
def test_ep(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    test_options: EPTestOptions,
    num_gpus_available,
):
    _compare_tp(model_name,
                parallel_setup,
                distributed_backend,
                task,
                test_options,
                num_gpus_available,
                method="generate")
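The test compares an EP-style launch against a plain TP launch of the same model through compare_two_settings: ep_args carries the EP environment override while tp_args pins the mp backend, and the helper compares the two runs. With the fast() settings above, the single Mixtral entry expands to two parametrized cases (two ParallelSetups, ray backend only), each run in a fresh process via fork_new_process_for_each_test. Below is a minimal, hypothetical sketch of a local invocation; the file path and the final environment-variable name are assumptions (at this commit the file still sets VLLM_TEST_EP_PARALLEL, which a later commit fixes).

# Hypothetical local invocation of the new test; the path, the env variable
# name, and the -k filter are assumptions, not part of this PR.
import os
import sys

import pytest

os.environ["VLLM_TEST_ENABLE_EP"] = "1"  # opt-in toggle added in commit 8485a9a
sys.exit(pytest.main([
    "-x",
    "tests/distributed/test_expert_parallel.py",
    "-k", "Mixtral",
]))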