
Commit 9c4edf4

Merge pull request #4 from LookAround0301/long_seq_tmp
Long seq tmp
2 parents 4e51fa8 + ab2cb7a commit 9c4edf4

10 files changed: +108 −17 lines

vllm/config/parallel.py

Lines changed: 5 additions & 1 deletion
@@ -41,6 +41,8 @@ class ParallelConfig:
     """Number of pipeline parallel groups."""
     tensor_parallel_size: int = 1
     """Number of tensor parallel groups."""
+    context_parallel_size: int = 1
+    """Number of context parallel groups."""
     data_parallel_size: int = 1
     """Number of data parallel groups. MoE layers will be sharded according to
     the product of the tensor parallel size and data parallel size."""
@@ -71,6 +73,8 @@ class ParallelConfig:
     between local data parallel ranks, but an external LB balances
     between vLLM nodes/replicas. Set explicitly in conjunction with
     --data-parallel-start-rank."""
+    enable_sequence_parallel: bool = False
+    """Enable sequence parallel."""
     enable_expert_parallel: bool = False
     """Use expert parallelism instead of tensor parallelism for MoE layers."""
     enable_eplb: bool = False
@@ -238,7 +242,7 @@ def compute_hash(self):
 
     def __post_init__(self) -> None:
         self.world_size = self.pipeline_parallel_size * \
-            self.tensor_parallel_size
+            self.tensor_parallel_size * self.context_parallel_size
 
         if self.data_parallel_size_local > self.data_parallel_size:
             raise ValueError(
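
With this change, each engine replica's world size becomes the product of the pipeline, tensor, and context parallel sizes; data parallelism stays outside this product, since every data-parallel rank runs its own PP x TP x CP worker set. A minimal sketch of the arithmetic (sizes are illustrative, not from the commit):

# Illustrative only: per-replica world size after this change,
# mirroring ParallelConfig.__post_init__ above.
pipeline_parallel_size = 2
tensor_parallel_size = 4
context_parallel_size = 2
world_size = (pipeline_parallel_size * tensor_parallel_size *
              context_parallel_size)
assert world_size == 16  # 2 PP stages x 4 TP ranks x 2 CP ranks each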

vllm/distributed/parallel_state.py

Lines changed: 47 additions & 7 deletions
@@ -928,6 +928,24 @@ def get_pp_group() -> GroupCoordinator:
     return _PP
 
 
+_CP: Optional[GroupCoordinator] = None
+
+
+def get_cp_group() -> GroupCoordinator:
+    assert _CP is not None, ("context parallel group is not initialized")
+    return _CP
+
+
+def get_context_model_parallel_world_size():
+    """Return world size for the context model parallel group."""
+    return get_cp_group().world_size
+
+
+def get_context_model_parallel_rank():
+    """Return my rank for the context model parallel group."""
+    return get_cp_group().rank_in_group
+
+
 @deprecated("`get_pipeline_model_parallel_group` has been replaced with "
             "`get_pp_group` and may be removed in v0.12. Please use "
             "`get_pp_group` instead.")
@@ -1034,6 +1052,7 @@ def init_distributed_environment(
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
+    context_model_parallel_size: int = 1,
     backend: Optional[str] = None,
 ) -> None:
     """
@@ -1082,7 +1101,7 @@ def initialize_model_parallel(
     # to get group_ranks for each dimension, transpose that dimension to the
     # last dimension, then reshape to 2D, then unbind the last dimension
     all_ranks = torch.arange(world_size).reshape(
-        -1, data_parallel_size, pipeline_model_parallel_size,
+        -1, data_parallel_size, pipeline_model_parallel_size, context_model_parallel_size,
         tensor_model_parallel_size)  # noqa
 
     # Build the tensor model-parallel groups.
@@ -1102,7 +1121,7 @@ def initialize_model_parallel(
     global _PP
     assert _PP is None, (
         "pipeline model parallel group is already initialized")
-    group_ranks = all_ranks.transpose(2, 3).reshape(
+    group_ranks = all_ranks.transpose(2, 4).reshape(
         -1, pipeline_model_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _PP = init_model_parallel_group(group_ranks,
@@ -1113,7 +1132,7 @@ def initialize_model_parallel(
     global _DP
     assert _DP is None, ("data parallel group is already initialized")
     group_ranks = all_ranks.transpose(1,
-                                      3).reshape(-1,
+                                      4).reshape(-1,
                                                  data_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _DP = init_model_parallel_group(group_ranks,
@@ -1124,23 +1143,34 @@ def initialize_model_parallel(
     global _EP
     assert _EP is None, ("expert parallel group is already initialized")
     group_ranks = all_ranks.transpose(1, 2).reshape(
-        -1, data_parallel_size * tensor_model_parallel_size).unbind(0)
+        -1, data_parallel_size * tensor_model_parallel_size * context_model_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _EP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="ep")
 
+    global _CP
+    assert _CP is None, ("context parallel group is already initialized")
+    group_ranks = all_ranks.transpose(3, 4).reshape(
+        -1, context_model_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
+    _CP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank,
+                                    backend,
+                                    group_name="cp")
+
     logger.info(
         "rank %s in world size %s is assigned as "
-        "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
+        "DP rank %s, PP rank %s, TP rank %s, EP rank %s, CP rank %s", rank, world_size,
         _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
-        _EP.rank_in_group)
+        _EP.rank_in_group, _CP.rank_in_group)
 
 
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
     pipeline_model_parallel_size: int,
+    context_model_parallel_size: int,
     backend: Optional[str] = None,
 ) -> None:
     """Helper to initialize model parallel groups if they are not initialized,
@@ -1151,7 +1181,7 @@ def ensure_model_parallel_initialized(
         get_world_group().device_group)
     if not model_parallel_is_initialized():
         initialize_model_parallel(tensor_model_parallel_size,
-                                  pipeline_model_parallel_size, backend)
+                                  pipeline_model_parallel_size, context_model_parallel_size, backend)
         return
 
     assert (
@@ -1164,6 +1194,11 @@ def ensure_model_parallel_initialized(
         "pipeline parallel group already initialized, but of unexpected size. "
         f"got: {pp_world_size=} vs. "
         f"wanted: {pipeline_model_parallel_size=}")
+    cp_world_size = get_cp_group().world_size
+    assert (cp_world_size == context_model_parallel_size), (
+        "context parallel group already initialized, but of unexpected size: "
+        f"{cp_world_size=} vs. "
+        f"{context_model_parallel_size=}")
 
 
 def prepare_communication_buffer_for_model(model: torch.nn.Module):
@@ -1256,6 +1291,11 @@ def destroy_model_parallel():
         _EP.destroy()
     _EP = None
 
+    global _CP
+    if _CP:
+        _CP.destroy()
+    _CP = None
+
 
 def destroy_distributed_environment():
     global _WORLD, _NODE_COUNT
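
To make the new 5-D rank layout concrete, here is a small standalone sketch (not part of the commit) that mirrors the reshape/transpose/unbind pattern above. It assumes an 8-rank world with DP=1, PP=2, CP=2, TP=2; the TP grouping is inferred from the existing pattern, since that hunk is not shown in this diff.

import torch

# Illustrative sizes only; layout matches the hunk above: (ep, dp, pp, cp, tp).
data_parallel_size = 1
pipeline_model_parallel_size = 2
context_model_parallel_size = 2
tensor_model_parallel_size = 2
world_size = (data_parallel_size * pipeline_model_parallel_size *
              context_model_parallel_size * tensor_model_parallel_size)

all_ranks = torch.arange(world_size).reshape(
    -1, data_parallel_size, pipeline_model_parallel_size,
    context_model_parallel_size, tensor_model_parallel_size)

# TP groups: last dim varies fastest -> [[0, 1], [2, 3], [4, 5], [6, 7]]
tp_groups = [x.tolist() for x in
             all_ranks.reshape(-1, tensor_model_parallel_size).unbind(0)]

# CP groups: transpose(3, 4) moves the cp dim last -> [[0, 2], [1, 3], [4, 6], [5, 7]]
cp_groups = [x.tolist() for x in
             all_ranks.transpose(3, 4).reshape(
                 -1, context_model_parallel_size).unbind(0)]

# PP groups: transpose(2, 4) moves the pp dim last -> [[0, 4], [2, 6], [1, 5], [3, 7]]
pp_groups = [x.tolist() for x in
             all_ranks.transpose(2, 4).reshape(
                 -1, pipeline_model_parallel_size).unbind(0)]

print(tp_groups, cp_groups, pp_groups)

Each CP group pairs ranks that share the same DP, PP, and TP coordinates, which is what allows them to split one request's context between themselves.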

vllm/engine/arg_utils.py

Lines changed: 9 additions & 0 deletions
@@ -296,6 +296,7 @@ class EngineArgs:
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
+    context_parallel_size: int = ParallelConfig.context_parallel_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
     data_parallel_rank: Optional[int] = None
     data_parallel_start_rank: Optional[int] = None
@@ -304,6 +305,7 @@ class EngineArgs:
     data_parallel_rpc_port: Optional[int] = None
     data_parallel_hybrid_lb: bool = False
     data_parallel_backend: str = ParallelConfig.data_parallel_backend
+    enable_sequence_parallel: bool = ParallelConfig.enable_sequence_parallel
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     enable_eplb: bool = ParallelConfig.enable_eplb
     num_redundant_experts: int = ParallelConfig.num_redundant_experts
@@ -619,6 +621,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                     **parallel_kwargs["pipeline_parallel_size"])
         parallel_group.add_argument("--tensor-parallel-size", "-tp",
                                     **parallel_kwargs["tensor_parallel_size"])
+        parallel_group.add_argument("--context-parallel-size", "-cp",
+                                    **parallel_kwargs["context_parallel_size"])
         parallel_group.add_argument("--data-parallel-size", "-dp",
                                     **parallel_kwargs["data_parallel_size"])
         parallel_group.add_argument(
@@ -656,6 +660,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parallel_group.add_argument(
             "--data-parallel-hybrid-lb",
             **parallel_kwargs["data_parallel_hybrid_lb"])
+        parallel_group.add_argument(
+            "--enable-sequence-parallel",
+            **parallel_kwargs["enable_sequence_parallel"])
         parallel_group.add_argument(
             "--enable-expert-parallel",
             **parallel_kwargs["enable_expert_parallel"])
@@ -1247,6 +1254,7 @@ def create_engine_config(
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
+            context_parallel_size=self.context_parallel_size,
             data_parallel_size=self.data_parallel_size,
             data_parallel_rank=self.data_parallel_rank or 0,
             data_parallel_external_lb=data_parallel_external_lb,
@@ -1255,6 +1263,7 @@ def create_engine_config(
             data_parallel_rpc_port=data_parallel_rpc_port,
             data_parallel_backend=self.data_parallel_backend,
             data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
+            enable_sequence_parallel=self.enable_sequence_parallel,
             enable_expert_parallel=self.enable_expert_parallel,
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.num_redundant_experts,
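
Assuming the additions above land as shown, the new knobs can be reached either through the CLI flags registered in add_cli_args or programmatically via EngineArgs. A minimal sketch (the model name and sizes are made up; only the flag and field names come from the diff):

# Illustrative only: the new EngineArgs fields and the flags that back them.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",      # placeholder model
    tensor_parallel_size=2,
    context_parallel_size=2,        # backed by --context-parallel-size / -cp
    enable_sequence_parallel=True,  # backed by --enable-sequence-parallel
)
# create_engine_config() forwards both values into ParallelConfig (last two
# hunks above), giving parallel_config.world_size == 1 (PP) * 2 (TP) * 2 (CP).
# Roughly equivalent CLI, per the argparse additions:
#   vllm serve facebook/opt-125m -tp 2 -cp 2 --enable-sequence-parallel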

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 13 additions & 5 deletions
@@ -10,7 +10,7 @@
 
 import vllm.envs as envs
 from vllm.config import ParallelConfig
-from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
+from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank, get_context_model_parallel_rank
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -163,9 +163,11 @@ def make(
 @dataclass
 class FusedMoEParallelConfig:
     tp_size: int
+    cp_size: int
     dp_size: int
     ep_size: int
     tp_rank: int
+    cp_rank: int
     dp_rank: int
     ep_rank: int
 
@@ -197,7 +199,7 @@ def use_flashinfer_cutlass_kernels(self):
                 and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
 
     @staticmethod
-    def make(tp_size_: int, dp_size_: int,
+    def make(tp_size_: int, dp_size_: int, cp_size_: int,
              vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
         """
         Determine MoE parallel configuration. Based on the input `tp_size_`,
@@ -278,16 +280,20 @@ def flatten_tp_across_dp(dp_rank: int):
             tp_rank = dp_rank * tp_size_ + tp_rank
             return tp_size, tp_rank
 
-        use_ep = (dp_size_ * tp_size_ > 1
+        use_ep = (dp_size_ * tp_size_ * cp_size_ > 1
                   and vllm_parallel_config.enable_expert_parallel)
 
        dp_size = dp_size_
        dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
        tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
+        cp_size = cp_size_
+        cp_rank = get_context_model_parallel_rank() if cp_size_ > 1 else 0
 
         if not use_ep:
             return FusedMoEParallelConfig(tp_size=tp_size,
                                           tp_rank=tp_rank,
+                                          cp_size=cp_size,
+                                          cp_rank=cp_rank,
                                           dp_size=dp_size,
                                           dp_rank=dp_rank,
                                           ep_size=1,
@@ -297,10 +303,12 @@ def flatten_tp_across_dp(dp_rank: int):
         assert use_ep
         # In EP, each device owns a set of experts fully. There is no tensor
         # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that.
-        ep_size = tp_size
-        ep_rank = tp_rank
+        ep_size = tp_size * cp_size
+        ep_rank = tp_rank + tp_size * cp_rank
         return FusedMoEParallelConfig(tp_size=1,
                                       tp_rank=0,
+                                      cp_size=1,
+                                      cp_rank=0,
                                       dp_size=dp_size,
                                       dp_rank=dp_rank,
                                       ep_size=ep_size,
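
When expert parallelism is enabled, the EP rank is now flattened over both the TP and CP ranks. A standalone sketch of that indexing (not from the commit), with tp_size=4 and cp_size=2:

# Illustrative only: ep_size = tp_size * cp_size, ep_rank = tp_rank + tp_size * cp_rank.
tp_size, cp_size = 4, 2
ep_size = tp_size * cp_size  # 8 expert-parallel ranks in total
for cp_rank in range(cp_size):
    for tp_rank in range(tp_size):
        ep_rank = tp_rank + tp_size * cp_rank
        print(f"tp_rank={tp_rank} cp_rank={cp_rank} -> ep_rank={ep_rank}")
# tp_rank moves fastest: (0,0)->0 ... (3,0)->3, then (0,1)->4 ... (3,1)->7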

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 5 additions & 0 deletions
@@ -14,6 +14,7 @@
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_ep_group,
                               get_tensor_model_parallel_world_size,
+                              get_context_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.forward_context import ForwardContext, get_forward_context
@@ -734,6 +735,7 @@ def __init__(
         tp_size: Optional[int] = None,
         ep_size: Optional[int] = None,
         dp_size: Optional[int] = None,
+        cp_size: Optional[int] = None,
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
@@ -753,12 +755,15 @@ def __init__(
             get_tensor_model_parallel_world_size())
         dp_size_ = (dp_size
                     if dp_size is not None else get_dp_group().world_size)
+        cp_size_ = (cp_size
+                    if cp_size is not None else get_context_model_parallel_world_size())
 
         vllm_config = get_current_vllm_config()
         self.moe_parallel_config: FusedMoEParallelConfig = (
             FusedMoEParallelConfig.make(
                 tp_size_=tp_size_,
                 dp_size_=dp_size_,
+                cp_size_=cp_size_,
                 vllm_parallel_config=vllm_config.parallel_config))
 
         self.global_num_experts = num_experts + num_redundant_experts

vllm/v1/core/sched/output.py

Lines changed: 6 additions & 0 deletions
@@ -32,6 +32,7 @@ class NewRequestData:
     block_ids: tuple[list[int], ...]
     num_computed_tokens: int
     lora_request: Optional[LoRARequest]
+    num_computed_tokens_of_cp_sp: Optional[list[list[int]]]
 
     @classmethod
     def from_request(
@@ -50,6 +51,7 @@ def from_request(
             block_ids=block_ids,
             num_computed_tokens=request.num_computed_tokens,
             lora_request=request.lora_request,
+            num_computed_tokens_of_cp_sp=request.num_computed_tokens_of_cp_sp,
         )
 
     def __repr__(self):
@@ -93,6 +95,8 @@ class CachedRequestData:
     new_token_ids: list[list[int]]
     new_block_ids: list[tuple[list[int], ...]]
     num_computed_tokens: list[int]
+    kv_rank: list[tuple[int]]
+    num_computed_tokens_of_cp_sp: list[list[list[int]]]
 
     @property
     def num_reqs(self) -> int:
@@ -106,6 +110,8 @@ def make_empty(cls) -> CachedRequestData:
             new_token_ids=[],
             new_block_ids=[],
             num_computed_tokens=[],
+            kv_rank=[],
+            num_computed_tokens_of_cp_sp=[],
         )
 
 
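The new `num_computed_tokens_of_cp_sp` field replaces a single per-request counter with one counter per (context-parallel, sequence-parallel) rank pair. The commit does not document the exact layout, so the shape below is an illustrative reading of the type annotations, not a guarantee:

# Hypothetical per-request value, assuming CP=2 and SP=2: the outer index is
# the CP rank, the inner index is the SP rank, and each value is the number of
# tokens already computed by that rank.
num_computed_tokens_of_cp_sp = [
    [96, 96],  # CP rank 0
    [64, 64],  # CP rank 1
]
# CachedRequestData batches one such nested list per request,
# hence the list[list[list[int]]] annotation above.
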
vllm/v1/core/sched/scheduler.py

Lines changed: 8 additions & 0 deletions
@@ -635,6 +635,10 @@ def _make_cached_request_data(
         new_block_ids: list[tuple[list[int], ...]] = []
         num_computed_tokens: list[int] = []
 
+        # cp param
+        kv_rank: list[tuple[int]] = []
+        num_computed_tokens_of_cp_sp: list[list[list[int]]] = []
+
         use_connector = self.connector is not None
         for req in itertools.chain(running_reqs, resumed_reqs):
             req_id = req.request_id
@@ -657,6 +661,8 @@ def _make_cached_request_data(
                 new_token_ids.append([])
             new_block_ids.append(req_to_new_block_ids[req_id])
             num_computed_tokens.append(req.num_computed_tokens)
+            kv_rank.append(req.kv_rank)
+            num_computed_tokens_of_cp_sp.append(req.num_computed_tokens_of_cp_sp)
         # Because resumed_reqs is usually empty, it is more efficient to do
         # in-place appending so that we don't need to allocate a new list.
         resumed_from_preemption = [False] * len(running_reqs)
@@ -668,6 +674,8 @@ def _make_cached_request_data(
             new_token_ids=new_token_ids,
             new_block_ids=new_block_ids,
             num_computed_tokens=num_computed_tokens,
+            kv_rank=kv_rank,
+            num_computed_tokens_of_cp_sp=num_computed_tokens_of_cp_sp,
         )
 
     def _try_schedule_encoder_inputs(

vllm/v1/executor/multiproc_executor.py

Lines changed: 5 additions & 3 deletions
@@ -55,10 +55,12 @@ def _init_executor(self) -> None:
         self.world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
         pp_parallel_size = self.parallel_config.pipeline_parallel_size
-        assert self.world_size == tensor_parallel_size * pp_parallel_size, (
+        context_parallel_size = self.parallel_config.context_parallel_size
+        assert self.world_size == tensor_parallel_size * pp_parallel_size * context_parallel_size, (
             f"world_size ({self.world_size}) must be equal to the "
             f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
-            f"_parallel_size ({pp_parallel_size}). ")
+            f"_parallel_size ({pp_parallel_size}) x context"
+            f"_parallel_size ({context_parallel_size}). ")
 
         # Set multiprocessing envs that are common to V0 and V1
         set_multiprocessing_worker_envs(self.parallel_config)
@@ -323,7 +325,7 @@ def _get_output_rank(self) -> int:
         # 16-23, PP rank 2
         # 24-31, PP rank 3
         # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
-        return self.world_size - self.parallel_config.tensor_parallel_size
+        return self.world_size - self.parallel_config.tensor_parallel_size * self.parallel_config.context_parallel_size
 
 
 @dataclass
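
With CP in the product, the driver now reads the output from rank world_size - tp_size * cp_size, i.e. the first rank of the last PP stage. A worked version of the 32-rank example from the comment above, assuming TP=4 and CP=2 so that each PP stage still spans 8 ranks (illustrative sizes, not from the commit):

# Illustrative only: extends the comment's 32-rank example with CP.
world_size = 32
tensor_parallel_size = 4
context_parallel_size = 2
pipeline_parallel_size = 4
assert world_size == (tensor_parallel_size * context_parallel_size *
                      pipeline_parallel_size)

# PP rank 3 owns ranks 24..31; its first rank returns the model output.
output_rank = world_size - tensor_parallel_size * context_parallel_size
assert output_rank == 24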
