Merged

Changes from 22 commits (43 total)
949c914
MoE init
cakeng Jan 27, 2025
ba0549e
EP config integrations
cakeng Jan 28, 2025
7d7285d
Weight loading
cakeng Jan 29, 2025
991d126
Added EP info to moe_align_block_size
cakeng Jan 29, 2025
df154b6
Bugs
cakeng Jan 29, 2025
3d700c4
Working EP+TP Prototype
cakeng Jan 29, 2025
2d05161
Removed debugging print statements
cakeng Jan 30, 2025
60284b5
Fixes
cakeng Jan 30, 2025
9cdb728
Merge branch 'main' into moe
cakeng Jan 30, 2025
03b8afb
Fused MoE Kernel fixes, Errors on CudaGraph Capture
cakeng Jan 30, 2025
cdb252d
Merge branch 'vllm-project:main' into moe
cakeng Jan 30, 2025
f7dcd7b
Expert mapping fixes, num_experts does not need to be divisible by EP…
cakeng Jan 31, 2025
b3e00f5
Merge branch 'main' into moe
cakeng Feb 4, 2025
d8cb2b3
Merge branch 'main' into moe
cakeng Feb 4, 2025
837a0fb
Merge branch 'main' into moe
cakeng Feb 7, 2025
cefcef6
Integrating to DeepSeekV3
cakeng Feb 7, 2025
e550752
EP correctness on DeepSeekV3 checked
cakeng Feb 8, 2025
ea5fbdc
Moved expert_parallel_size argument to expertimental_expert_parallel_…
cakeng Feb 8, 2025
8485a9a
Added environment variable VLLM_TEST_ENABLE_EP to turn on EP, default…
cakeng Feb 11, 2025
4c7ba48
Merge branch 'main' into moe
cakeng Feb 11, 2025
5cadaa0
Errors running test scripts
cakeng Feb 13, 2025
83a7190
Fixed FusedMoE apply base method
cakeng Feb 13, 2025
7102ae7
Added test_expert_parallel.py and modified test_moe.py to include EP …
cakeng Feb 14, 2025
1951aa7
Pre-commit
cakeng Feb 14, 2025
b9c8c18
Merge branch 'main' into moe
cakeng Feb 14, 2025
1254ba6
Clean debug print statements
cakeng Feb 14, 2025
c8a6c64
Fixed FusedMoE apply function signatures
cakeng Feb 14, 2025
b721ee7
Fixed remaining FusedMoE apply function signatures
cakeng Feb 14, 2025
2a9993b
Removed debugging print statement.
cakeng Feb 14, 2025
56f9828
Update parallel_state.py
cakeng Feb 14, 2025
5e53fbb
Update parallel_state.py
cakeng Feb 14, 2025
182423a
Reduce diffs
cakeng Feb 15, 2025
5ced495
Updated comments from Lucas and Tyler
cakeng Feb 19, 2025
9bf31be
Made TP rank setting more explicit during the FusedMoE weight loading
cakeng Feb 19, 2025
88ce3a9
Merge branch 'main' into moe
tlrmchlsmth Feb 20, 2025
99e2d98
test_moe.py fixes
cakeng Feb 21, 2025
ee3f981
Merge branch 'main' into moe
cakeng Feb 21, 2025
5df1e06
Merge branch 'main' into moe
cakeng Feb 21, 2025
c5a21c6
fix test_expert_parallel.py env_variable
cakeng Feb 22, 2025
accb533
fix test_expert_parallel.py test configurations
cakeng Feb 22, 2025
14864f7
Merge branch 'main' into moe
cakeng Feb 23, 2025
a0f2c21
Fix pre-commit
cakeng Feb 23, 2025
5554d35
Revert FusedMoE num_experts variable name change and add deepseekV2 t…
cakeng Feb 24, 2025
229 changes: 229 additions & 0 deletions tests/distributed/test_expert_parallel.py
@@ -0,0 +1,229 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import List, Literal, NamedTuple, Optional

import pytest

from vllm.config import TaskOption
from vllm.logger import init_logger

from ..utils import compare_two_settings, fork_new_process_for_each_test

logger = init_logger("test_pipeline_parallel")


class ParallelSetup(NamedTuple):
    tp_size: int
    eager_mode: bool
    chunked_prefill: bool


class EPTestOptions(NamedTuple):
    trust_remote_code: bool
    tokenizer_mode: Optional[str]
    load_format: Optional[str] = None
    hf_overrides: Optional[str] = None


@dataclass
class EPTestSettings:
    parallel_setups: List[ParallelSetup]
    distributed_backends: List[str]
    task: TaskOption
    test_options: EPTestOptions

    @staticmethod
    def detailed(
        *,
        tp_base: int = 2,
        task: TaskOption = "auto",
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
        load_format: Optional[str] = None,
        hf_overrides: Optional[str] = None,
    ):
        return EPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              eager_mode=False,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=tp_base,
                              eager_mode=True,
                              chunked_prefill=False),
                ParallelSetup(tp_size=2 * tp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=2 * tp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp", "ray"],
            task=task,
            test_options=EPTestOptions(trust_remote_code=trust_remote_code,
                                       tokenizer_mode=tokenizer_mode,
                                       load_format=load_format,
                                       hf_overrides=hf_overrides),
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 2,
        task: TaskOption = "auto",
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
        load_format: Optional[str] = None,
        hf_overrides: Optional[str] = None,
    ):
        return EPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              eager_mode=True,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              eager_mode=False,
                              chunked_prefill=True),
            ],
            distributed_backends=["ray"],
            task=task,
            test_options=EPTestOptions(trust_remote_code=trust_remote_code,
                                       tokenizer_mode=tokenizer_mode,
                                       load_format=load_format,
                                       hf_overrides=hf_overrides),
        )

    def iter_params(self, model_name: str):
        opts = self.test_options

        for parallel_setup in self.parallel_setups:
            for distributed_backend in self.distributed_backends:
                yield (model_name, parallel_setup, distributed_backend,
                       self.task, opts)


# NOTE: You can adjust tp_base locally to fit the model in GPU
# The values displayed here are only a rough indicator of the size of the model

# yapf: disable
TEST_MODELS = {
    # "ai21labs/Jamba-v0.1": EPTestSettings.fast(trust_remote_code=True),
    # "deepseek-ai/deepseek-llm-7b-chat": EPTestSettings.fast(
    #     trust_remote_code=True),
    # "deepseek-ai/DeepSeek-V2-Lite-Chat": EPTestSettings.fast(
    #     trust_remote_code=True),
    "mistralai/Mixtral-8x7B-Instruct-v0.1": EPTestSettings.fast(tp_base=4),
}


def _compare_tp(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    test_options: EPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate", "encode"],
):
    (
        tp_size,
        eager_mode,
        chunked_prefill,
    ) = parallel_setup
    (
        trust_remote_code,
        tokenizer_mode,
        load_format,
        hf_overrides,
    ) = test_options

    if num_gpus_available < tp_size:
        pytest.skip(f"Need at least {tp_size} GPUs")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if task != "auto":
        common_args.extend(["--task", task])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", hf_overrides])

    ep_env = {
        "VLLM_TEST_ENABLE_EP": "1",
    }

    ep_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
    ]

    # compare without pipeline parallelism
    # NOTE: use mp backend for TP
    # PP tests might involve multiple nodes, and ray might
    # schedule all workers in a node other than the head node,
    # which can cause the test to fail.
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    try:
        compare_two_settings(model_name,
                             ep_args,
                             tp_args,
                             ep_env,
                             method=method)
    except Exception:
        raise


@pytest.mark.parametrize(
    ("model_name", "parallel_setup", "distributed_backend", "task",
     "test_options"),
    [
        params for model_name, settings in TEST_MODELS.items()
        for params in settings.iter_params(model_name)
    ],
)
@fork_new_process_for_each_test
def test_ep(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    test_options: EPTestOptions,
    num_gpus_available,
):
    _compare_tp(model_name,
                parallel_setup,
                distributed_backend,
                task,
                test_options,
                num_gpus_available,
                method="generate")
4 changes: 4 additions & 0 deletions tests/utils.py
@@ -500,15 +500,19 @@ def compare_all_settings(model: str,
            })

            if method == "generate":
                print(f"Testing generate {model=} {prompt=} {token_ids=}")
                results += _test_completion(client, model, prompt, token_ids)
            elif method == "generate_close":
                print(f"Testing generate_close {model=} {prompt=}")
                results += _test_completion_close(client, model, prompt)
            elif method == "generate_with_image":
                print(f"Testing generate_with_image {model=} {prompt=}")
                results += _test_image_text(
                    client, model,
                    "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png"
                )
            elif method == "encode":
                print(f"Testing encode {model=} {prompt=}")
                results += _test_embeddings(client, model, prompt)
            else:
                raise ValueError(f"Unknown method: {method}")
18 changes: 18 additions & 0 deletions vllm/config.py
@@ -712,6 +712,23 @@ def verify_with_parallel_config(
" must be divisible by tensor parallel size "
f"({tensor_parallel_size}).")

if envs.VLLM_TEST_ENABLE_EP:
num_expert_names = [
"moe_num_experts", # Dbrx
"num_experts", # Jamba
"n_routed_experts", # DeepSeek
"num_local_experts", # Mixtral
]
num_experts = 0
for name in num_expert_names:
num_experts = getattr(self.hf_text_config, name, 0)
if num_experts > 0:
break
if num_experts < 1:
raise ValueError(
"Number of experts in the model must be greater than 0 "
"when using expert parallelism.")

pipeline_parallel_size = parallel_config.pipeline_parallel_size
if pipeline_parallel_size > 1:
architectures = getattr(self.hf_config, "architectures", [])
@@ -1356,6 +1373,7 @@ def compute_hash(self):
        return hashlib.sha256(str(factors).encode()).hexdigest()

    def __post_init__(self) -> None:

        self.world_size = self.pipeline_parallel_size * \
            self.tensor_parallel_size

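As an illustration of what the new expert-count check probes (a sketch, not part of this diff; MixtralConfig comes from HuggingFace Transformers and is used here only as an example of a config that exposes one of the listed attribute names):

from transformers import MixtralConfig

cfg = MixtralConfig()
# Mixtral reports its expert count under `num_local_experts`, one of the
# names probed above; the default config has 8 experts, so with
# VLLM_TEST_ENABLE_EP set the validation would pass.
assert getattr(cfg, "num_local_experts", 0) > 0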
6 changes: 4 additions & 2 deletions vllm/distributed/parallel_state.py
@@ -901,6 +901,8 @@ def get_tp_group() -> GroupCoordinator:

_PP: Optional[GroupCoordinator] = None

_EP_SIZE: Optional[int] = None


def get_pp_group() -> GroupCoordinator:
    assert _PP is not None, (
@@ -1004,7 +1006,7 @@ def initialize_model_parallel(
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
@@ -1094,7 +1096,7 @@ def ensure_model_parallel_initialized(
        get_world_group().device_group)
    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)
                                  pipeline_model_parallel_size)
        return

    assert (
7 changes: 7 additions & 0 deletions vllm/envs.py
@@ -84,6 +84,7 @@
    VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
    VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
    VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
    VLLM_TEST_ENABLE_EP: bool = False
    VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
    VLLM_RAY_PER_WORKER_GPUS: float = 1.0
    VLLM_RAY_BUNDLE_INDICES: str = ""
@@ -552,6 +553,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
),

# If set, vLLM will use the experimental expert parallel implementation on
# the FusedMoE layer, using tensor parallelism size as expert parallelism
# size.
"VLLM_TEST_ENABLE_EP":
lambda: bool(int(os.getenv("VLLM_TEST_ENABLE_EP", "0"))),

# Number of GPUs per worker in Ray, if it is set to be a fraction,
# it allows ray to schedule multiple actors on a single GPU,
# so that users can colocate other actors on the same GPUs as vLLM.
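To sketch how the new flag is meant to be used (illustrative only, not part of this PR; the model name and tensor_parallel_size below are assumptions), the experimental expert-parallel path can be enabled by setting the variable before starting the engine, with the expert-parallel size taken from the tensor-parallel size:

import os

os.environ["VLLM_TEST_ENABLE_EP"] = "1"  # read lazily by vllm.envs

from vllm import LLM, SamplingParams

llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1",
          tensor_parallel_size=4)
out = llm.generate(["Who are you?"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)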