37 changes: 35 additions & 2 deletions tests/distributed/test_sequence_parallel.py
@@ -26,6 +26,7 @@

class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
sp_enabled: bool
eager_mode: bool
chunked_prefill: bool
@@ -60,25 +61,50 @@ def __post_init__(self):
def detailed(
*,
tp_base: int = 2,
pp_base: int = 1,
multi_node_only: bool = False,
task: TaskOption = "auto",
load_format: Optional[str] = None,
):
return SPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=True),
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=True),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=True),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=True)
@@ -94,13 +120,20 @@ def detailed(
def fast(
*,
tp_base: int = 2,
pp_base: int = 1,
task: TaskOption = "auto",
multi_node_only: bool = False,
load_format: Optional[str] = None,
):
return SPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=False),
@@ -136,6 +169,7 @@ def _compare_sp(
):
(
tp_size,
pp_size,
sp_enabled,
eager_mode,
chunked_prefill,
@@ -167,7 +201,6 @@ def _compare_sp(
else:
model_info.check_available_online(on_fail="skip")

- pp_size = 1
if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
if VLLM_MULTI_NODE and distributed_backend == "mp":
@@ -256,7 +289,7 @@ def _compare_sp(

SP_TEXT_GENERATION_MODELS = {
# [Decoder-only]
"meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.detailed(),
"meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(),
}

SP_TEST_MODELS = [
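For reference, a quick sketch of what the updated fast() settings expand to with the defaults above (tp_base=2, pp_base=1). This is illustrative only, not part of the test suite, and the import path is an assumption about running from the repository root; each ParallelSetup now carries a pp_size, and _compare_sp skips a case when fewer than tp_size * pp_size GPUs are available.

# Sketch only: enumerate the setups produced by SPTestSettings.fast() and the
# GPU count each case needs before _compare_sp would skip it.
from tests.distributed.test_sequence_parallel import SPTestSettings  # assumed import path

settings = SPTestSettings.fast()  # tp_base=2, pp_base=1 by default
for setup in settings.parallel_setups:
    required_gpus = setup.tp_size * setup.pp_size
    print(f"tp={setup.tp_size} pp={setup.pp_size} -> needs {required_gpus} GPUs")
# Expected: tp=2 pp=1 -> needs 2 GPUs, then tp=2 pp=2 -> needs 4 GPUs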
12 changes: 0 additions & 12 deletions vllm/config.py
@@ -4222,18 +4222,6 @@ def __post_init__(self):
self.compilation_config.level = CompilationLevel.PIECEWISE
self.compilation_config.set_splitting_ops_for_v1()

- if self.parallel_config is not None and \
- self.parallel_config.tensor_parallel_size > 1 and \
- self.parallel_config.pipeline_parallel_size > 1 and \
- self.compilation_config is not None and \
- self.compilation_config.pass_config is not None and \
- self.compilation_config.pass_config.enable_sequence_parallelism:
- logger.warning_once(
- "Sequence parallelism is not supported with pipeline "
- "parallelism. Disabling sequence parallelism.")
- self.compilation_config.pass_config.\
- enable_sequence_parallelism = False

self._set_cudagraph_sizes()

if self.cache_config is not None and \
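With this guard removed, a request for sequence parallelism is no longer silently dropped when pipeline parallelism is also enabled. A minimal sketch of how the two might now be combined through the Python entrypoint (the model name and parallel sizes are placeholders, and passing compilation_config as a nested dict is an assumption about the LLM constructor, not something shown in this diff):

from vllm import LLM

# Sketch: request the sequence-parallelism compilation pass together with
# tensor and pipeline parallelism. Values are illustrative only.
llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    tensor_parallel_size=2,
    pipeline_parallel_size=2,
    compilation_config={"pass_config": {"enable_sequence_parallelism": True}},
)
print(llm.generate("Hello")[0].outputs[0].text)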
30 changes: 26 additions & 4 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1145,11 +1145,23 @@ def execute_model(
else:
assert intermediate_tensors is not None
assert self.intermediate_tensors is not None
tp = self.vllm_config.parallel_config.tensor_parallel_size
enabled_sp = self.vllm_config.compilation_config.pass_config. \
enable_sequence_parallelism
is_residual_scattered = tp > 1 and enabled_sp \
and num_input_tokens % tp == 0

for k, v in intermediate_tensors.items():
- self.intermediate_tensors[k][:num_input_tokens].copy_(
- v[:num_input_tokens], non_blocking=True)
is_scattered = "residual" and is_residual_scattered
copy_len = num_input_tokens // tp if is_scattered else \
num_input_tokens
self.intermediate_tensors[k][:copy_len].copy_(
v[:copy_len], non_blocking=True)

intermediate_tensors = IntermediateTensors({
- k: v[:num_input_tokens]
k:
v[:num_input_tokens // tp] if k == "residual"
and is_residual_scattered else v[:num_input_tokens]
for k, v in self.intermediate_tensors.items()
})

@@ -1172,6 +1184,7 @@

if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
# print(f"cascade pp is not last rank, return hidden states")
Review comment from the contributor (author): for debug usage, will remove.

return hidden_states

sample_hidden_states = hidden_states[logits_indices]
@@ -1568,8 +1581,17 @@ def _dummy_run(
batch_size=self.max_num_tokens,
dtype=self.model_config.dtype,
device=self.device))

tp = self.vllm_config.parallel_config.tensor_parallel_size
enabled_sp = self.vllm_config.compilation_config.pass_config. \
enable_sequence_parallelism
is_residual_scattered = tp > 1 and enabled_sp \
and num_tokens % tp == 0

intermediate_tensors = IntermediateTensors({
- k: v[:num_tokens]
k:
v[:num_tokens // tp] if k == "residual"
and is_residual_scattered else v[:num_tokens]
for k, v in self.intermediate_tensors.items()
})

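The idea shared by both hunks above: when the sequence-parallelism pass is enabled and the token count divides evenly by the tensor-parallel size, the "residual" tensor exchanged between pipeline stages is already scattered across TP ranks, so only the first num_tokens // tp rows are valid and need to be copied or sliced, while every other intermediate tensor keeps its full num_tokens rows. A self-contained sketch of that shape logic (tensor names and sizes are made up for illustration):

import torch

def copy_intermediates(dst: dict, src: dict, num_tokens: int,
                       tp: int, sp_enabled: bool) -> None:
    # Copy a per-key prefix; only "residual" uses the shorter scattered length.
    residual_scattered = tp > 1 and sp_enabled and num_tokens % tp == 0
    for k, v in src.items():
        copy_len = num_tokens // tp if (k == "residual"
                                        and residual_scattered) else num_tokens
        dst[k][:copy_len].copy_(v[:copy_len], non_blocking=True)

# Toy example: hidden_states keeps all rows, residual only num_tokens // tp rows.
tp, num_tokens, hidden = 2, 8, 16
src = {
    "hidden_states": torch.randn(num_tokens, hidden),
    "residual": torch.randn(num_tokens // tp, hidden),  # already scattered
}
dst = {k: torch.empty(num_tokens, hidden) for k in src}  # preallocated buffers
copy_intermediates(dst, src, num_tokens, tp, sp_enabled=True)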