
Commit 4516d44

Authored by gjc0824, gaojingchun (A), Jingchun Gao, and pisceskkk
[DCP] Support Decode Context Parallel (DCP) for GQA with Flashinfer (#25438)
Signed-off-by: gaojc <[email protected]>
Signed-off-by: Jingchun Gao <[email protected]>
Signed-off-by: Jingchun Gao <[email protected]>
Signed-off-by: QiuChunshuo <[email protected]>
Co-authored-by: gaojingchun (A) <[email protected]>
Co-authored-by: Jingchun Gao <[email protected]>
Co-authored-by: QiuChunshuo <[email protected]>
1 parent 41b92f7 commit 4516d44

5 files changed: +331 −51 lines


tests/distributed/test_context_parallel.py

Lines changed: 15 additions & 2 deletions
@@ -39,6 +39,7 @@ class ParallelSetup(NamedTuple):
 class CPTestOptions(NamedTuple):
     multi_node_only: bool
     load_format: str | None = None
+    attn_backend: str | None = None


 @dataclass
@@ -58,6 +59,7 @@ def detailed(
         multi_node_only: bool = False,
         runner: RunnerOption = "auto",
         load_format: str | None = None,
+        attn_backend: str | None = None,
     ):
         parallel_setups = []
         for eager_mode_val in [False]:
@@ -79,7 +81,9 @@ def detailed(
                     distributed_backends=["mp"],
                     runner=runner,
                     test_options=CPTestOptions(
-                        multi_node_only=multi_node_only, load_format=load_format
+                        multi_node_only=multi_node_only,
+                        load_format=load_format,
+                        attn_backend=attn_backend,
                     ),
                 )

@@ -117,7 +121,7 @@ def _compare_cp_with_tp(
         chunked_prefill,
     ) = parallel_setup

-    multi_node_only, load_format = test_options
+    multi_node_only, load_format, attn_backend = test_options

     model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
     model_info.check_transformers_version(on_fail="skip")
@@ -177,6 +181,13 @@ def _compare_cp_with_tp(
     if hf_overrides:
         common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

+    if not attn_backend:
+        cp_env = tp_env = {}
+    else:
+        cp_env = tp_env = {
+            "VLLM_ATTENTION_BACKEND": attn_backend,
+        }
+
     cp_args = [
         *common_args,
         "--tensor-parallel-size",
@@ -205,6 +216,8 @@ def _compare_cp_with_tp(
         model_id,
         cp_args,
         tp_args,
+        cp_env,
+        tp_env,
         method=method,
         max_wait_seconds=720,
     )
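
How the new option takes effect: when a test case sets attn_backend, both the CP and the TP servers receive VLLM_ATTENTION_BACKEND in their environment, so the two configurations being compared run on the same attention kernels. A minimal sketch of that flow (the env-merge step is an assumption about the test harness; only VLLM_ATTENTION_BACKEND and CPTestOptions come from the diff):

import os

# Illustrative: mirrors the env construction in the diff above.
attn_backend = "FLASHINFER"  # e.g. via CPTestOptions(attn_backend="FLASHINFER")
cp_env = tp_env = {"VLLM_ATTENTION_BACKEND": attn_backend} if attn_backend else {}

# Assumed harness behavior: the env dict is merged into each server's
# environment before launch.
server_env = {**os.environ, **cp_env}
print(server_env.get("VLLM_ATTENTION_BACKEND"))  # FLASHINFER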

vllm/config/model.py

Lines changed: 8 additions & 0 deletions
@@ -1183,6 +1183,14 @@ def verify_with_parallel_config(
                 f"but got {decode_context_parallel_size}"
             )

+        num_q_per_kv = total_num_attention_heads // total_num_kv_heads
+        assert num_q_per_kv % decode_context_parallel_size == 0, (
+            f"Number of query heads per KV head ({num_q_per_kv}) must be "
+            "divisible by the DCP world size "
+            f"({parallel_config.decode_context_parallel_size}) when decode "
+            "context parallel is enabled for GQA."
+        )
+
     def get_sliding_window(self) -> int | None:
         """Get the sliding window size from the HF text config if present."""
         return getattr(self.hf_text_config, "sliding_window", None)
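
To make the new constraint concrete, a small worked example with illustrative head counts: a GQA model with 32 query heads and 8 KV heads has num_q_per_kv = 4, so DCP world sizes 1, 2, and 4 pass the check while 3 is rejected.

# Illustrative values; the check mirrors the assertion added above.
total_num_attention_heads = 32  # query heads
total_num_kv_heads = 8          # KV heads (GQA: 4 query heads share each KV head)
num_q_per_kv = total_num_attention_heads // total_num_kv_heads  # -> 4

for dcp_world_size in (1, 2, 3, 4):
    ok = num_q_per_kv % dcp_world_size == 0
    print(f"decode_context_parallel_size={dcp_world_size}: "
          f"{'valid' if ok else 'rejected'}")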

vllm/utils/flashinfer.py

Lines changed: 9 additions & 0 deletions
@@ -259,6 +259,7 @@ def use_trtllm_attention(
     num_kv_heads: int,
     num_tokens: int,
     max_seq_len: int,
+    dcp_world_size: int,
     kv_cache_dtype: str,
     q_dtype: torch.dtype,
     is_prefill: bool,
@@ -272,6 +273,14 @@ def use_trtllm_attention(
     if force_use_trtllm is not None and not force_use_trtllm:
         return False

+    # Decode context parallel is not supported by the TRTLLM kernels
+    if dcp_world_size > 1:
+        logger.warning_once(
+            "TRTLLM attention does not support returning the LSE and as a "
+            "result does not support DCP; reverting to FlashInfer"
+        )
+        return False
+
     # The platform is not supported
     if not supports_trtllm_attention():
         if force_use_trtllm:
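
Why the LSE matters for DCP: each DCP rank attends over only its shard of the KV cache, and the per-rank partial outputs can only be combined into the exact full-attention result if every rank also returns the log-sum-exp (LSE) of its attention scores. A minimal numerical sketch of that merge in PyTorch (illustrative, not the kernel vLLM or FlashInfer actually uses):

import torch

def merge_attn_states(outs: torch.Tensor, lses: torch.Tensor) -> torch.Tensor:
    """Merge per-shard attention outputs weighted by their LSE values.

    outs: [num_shards, head_dim] partial attention outputs
    lses: [num_shards] log-sum-exp of each shard's attention scores
    """
    lse_max = lses.max()                 # subtract the max for stability
    weights = torch.exp(lses - lse_max)  # each shard's softmax mass
    return (outs * weights.unsqueeze(-1)).sum(dim=0) / weights.sum()

# Example: two DCP ranks, each producing a partial output and an LSE.
outs = torch.randn(2, 128)
lses = torch.tensor([3.2, 1.7])
print(merge_attn_states(outs, lses).shape)  # torch.Size([128])

Because the TRTLLM decode path does not expose these LSE values, the commit routes DCP runs back to the FlashInfer path, which does.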
