diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py
index 149b502a85a7..5495640af07e 100644
--- a/tests/distributed/test_context_parallel.py
+++ b/tests/distributed/test_context_parallel.py
@@ -204,17 +204,21 @@ def _compare_cp_with_tp(
 
 CP_TEXT_GENERATION_MODELS = {
-    # [MLA attention only]
     "deepseek-ai/DeepSeek-V2-Lite-Chat": [
         CPTestSettings.detailed(),
         CPTestSettings.detailed(tp_base=2),
     ],
+    "bigcode/gpt_bigcode-santacoder": [
+        CPTestSettings.detailed(),
+        CPTestSettings.detailed(tp_base=2),
+    ],
 }
 
 CP_TEST_MODELS = [
     # TODO support other models
     # [LANGUAGE GENERATION]
     "deepseek-ai/DeepSeek-V2-Lite-Chat",
+    "bigcode/gpt_bigcode-santacoder",
 ]
diff --git a/tests/models/registry.py b/tests/models/registry.py
index fbc11c2ddfd4..237b3ec29362 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -265,7 +265,10 @@ def check_available_online(
     "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}),
     "GPTBigCodeForCausalLM": _HfExamplesInfo(
         "bigcode/starcoder",
-        extras={"tiny": "bigcode/tiny_starcoder_py"},
+        extras={
+            "tiny": "bigcode/tiny_starcoder_py",
+            "santacoder": "bigcode/gpt_bigcode-santacoder",
+        },
         min_transformers_version="4.55.1",
         transformers_version_reason="HF model broken in 4.55.0",
     ),
diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py
index 1234e1b2e46a..b6b7ecd2552a 100644
--- a/vllm/attention/ops/common.py
+++ b/vllm/attention/ops/common.py
@@ -173,6 +173,7 @@ def cp_lse_ag_out_rs(
     cp_attn_lse: torch.Tensor,
     cp_group: GroupCoordinator,
     ctx: CPTritonContext = None,
+    return_lse=False,
 ):
     """
     cp_attn_out: [ B, H, D ]
@@ -192,8 +193,15 @@ def cp_lse_ag_out_rs(
     cp_attn_lse = cp_attn_lse.contiguous()
     lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
-    out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+    out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+    assert out.is_contiguous()
     out = cp_group.reduce_scatter(out, dim=1)
+
+    if return_lse:
+        cp_num_heads = lse.shape[1] // cp_group.world_size
+        cp_rank = cp_group.rank_in_group
+        lse = lse[:, cp_num_heads * cp_rank : cp_num_heads * (cp_rank + 1)]
+        return out, lse
     return out
diff --git a/vllm/config/model.py b/vllm/config/model.py
index a2dcf5210754..c0e68bd53d8c 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -1201,6 +1201,23 @@ def verify_with_parallel_config(
                 "Supported models implement the `SupportsPP` interface."
             )
 
+        decode_context_parallel_size = parallel_config.decode_context_parallel_size
+        if decode_context_parallel_size > 1 and not self.use_mla:
+            total_num_kv_heads = self.get_total_num_kv_heads()
+            assert tensor_parallel_size > total_num_kv_heads, (
+                f"tensor parallel size {tensor_parallel_size} must be greater "
+                f"than total num kv heads {total_num_kv_heads} when decode "
+                f"context parallel is enabled for GQA/MQA"
+            )
+
+            max_dcp_size = tensor_parallel_size // total_num_kv_heads
+            assert decode_context_parallel_size <= max_dcp_size, (
+                f"decode context parallel size must be less than or equal to "
+                f"(tensor parallel size {tensor_parallel_size} // total "
+                f"num kv heads {total_num_kv_heads}) = {max_dcp_size}, "
+                f"but got {decode_context_parallel_size}"
+            )
+
     def get_sliding_window(self) -> int | None:
         """Get the sliding window size from the HF text config if present."""
         return getattr(self.hf_text_config, "sliding_window", None)
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index fb5ff499de2c..fa4e34536135 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -17,6 +17,7 @@
     is_quantized_kv_cache,
 )
 from vllm.attention.layer import Attention
+from vllm.attention.ops.common import cp_lse_ag_out_rs
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import (
     flash_attn_supports_fp8,
@@ -32,6 +33,7 @@
 )
 from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (
@@ -147,6 +149,10 @@ class FlashAttentionMetadata:
     prefix_kv_lens: torch.Tensor | None
     suffix_kv_lens: torch.Tensor | None
 
+    # For GQA DCP
+    max_dcp_context_kv_len: int | None = None
+    dcp_context_kv_lens: torch.Tensor | None = None
+
     # Optional aot scheduling
     scheduler_metadata: torch.Tensor | None = None
     prefix_scheduler_metadata: torch.Tensor | None = None
@@ -216,6 +222,16 @@ def __init__(
         self.max_num_splits = 0  # No upper bound on the number of splits.
         self.aot_schedule = get_flash_attn_version() == 3
 
+        try:
+            from vllm.distributed.parallel_state import get_dcp_group
+
+            self.dcp_world_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+        except AssertionError:
+            # DCP might not be initialized in testing
+            self.dcp_world_size = 1
+            self.dcp_rank = 0
+
         self.use_full_cuda_graph = (
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         )
@@ -306,7 +322,7 @@ def schedule(
                     batch_size=batch_size,
                     max_seqlen_q=max_query_len,
                     max_seqlen_k=max_seq_len,
-                    num_heads_q=self.num_heads_q,
+                    num_heads_q=self.num_heads_q * self.dcp_world_size,
                     num_heads_kv=self.num_heads_kv,
                     headdim=self.headdim,
                     cache_seqlens=seqlens,
@@ -320,8 +336,35 @@ def schedule(
             return None
 
         use_cascade = common_prefix_len > 0
+        max_dcp_context_kv_len = 0
+        dcp_context_kv_lens = None
+
+        cu_prefix_query_lens = None
+        prefix_kv_lens = None
+        suffix_kv_lens = None
+        prefix_scheduler_metadata = None
+
+        if self.dcp_world_size > 1:
+            query_kv_lens_cpu = (
+                common_attn_metadata.query_start_loc_cpu[1:]
+                - common_attn_metadata.query_start_loc_cpu[:-1]
+            )
+            dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
+            dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + (
+                self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size
+            )
+            dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
+            max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
 
-        if use_cascade:
+            scheduler_metadata = schedule(
+                batch_size=num_reqs,
+                cu_query_lens=query_start_loc,
+                max_query_len=max_query_len,
+                seqlens=dcp_context_kv_lens,
+                max_seq_len=max_dcp_context_kv_len,
+                causal=False,
+            )
+        elif use_cascade:
             cu_prefix_query_lens = torch.tensor(
                 [0, num_actual_tokens], dtype=torch.int32, device=self.device
             )
@@ -348,10 +391,6 @@ def schedule(
                 causal=True,
            )
         else:
-            cu_prefix_query_lens = None
-            prefix_kv_lens = None
-            suffix_kv_lens = None
-            prefix_scheduler_metadata = None
             scheduler_metadata = schedule(
                 batch_size=num_reqs,
                 cu_query_lens=query_start_loc,
@@ -379,6 +418,8 @@ def schedule(
             seq_lens=seq_lens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
+            max_dcp_context_kv_len=max_dcp_context_kv_len,
+            dcp_context_kv_lens=dcp_context_kv_lens,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
             scheduler_metadata=scheduler_metadata,
@@ -396,6 +437,8 @@ def use_cascade_attention(self, *args, **kwargs) -> bool:
 
 
 class FlashAttentionImpl(AttentionImpl):
+    can_return_lse_for_decode: bool = True
+
     def __init__(
         self,
         num_heads: int,
@@ -562,30 +605,45 @@ def forward(
             descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
 
-            flash_attn_varlen_func(
-                q=query[:num_actual_tokens],
-                k=key_cache,
-                v=value_cache,
-                out=output[:num_actual_tokens],
-                cu_seqlens_q=cu_seqlens_q,
-                max_seqlen_q=max_seqlen_q,
-                seqused_k=seqused_k,
-                max_seqlen_k=max_seqlen_k,
-                softmax_scale=self.scale,
-                causal=attn_metadata.causal,
-                alibi_slopes=self.alibi_slopes,
-                window_size=self.sliding_window,
-                block_table=block_table,
-                softcap=self.logits_soft_cap,
-                scheduler_metadata=scheduler_metadata,
-                fa_version=self.vllm_flash_attn_version,
-                q_descale=layer._q_scale.expand(descale_shape),
-                k_descale=layer._k_scale.expand(descale_shape),
-                v_descale=layer._v_scale.expand(descale_shape),
-                num_splits=attn_metadata.max_num_splits,
-                s_aux=self.sinks,
-            )
-            return output
+            if self.dcp_world_size > 1:
+                self._forward_with_dcp(
+                    query[:num_actual_tokens],
+                    key[:num_actual_tokens],
+                    value[:num_actual_tokens],
+                    key_cache,
+                    value_cache,
+                    output[:num_actual_tokens],
+                    attn_metadata,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
+                )
+                return output
+            else:
+                flash_attn_varlen_func(
+                    q=query[:num_actual_tokens],
+                    k=key_cache,
+                    v=value_cache,
+                    out=output[:num_actual_tokens],
+                    cu_seqlens_q=cu_seqlens_q,
+                    max_seqlen_q=max_seqlen_q,
+                    seqused_k=seqused_k,
+                    max_seqlen_k=max_seqlen_k,
+                    softmax_scale=self.scale,
+                    causal=attn_metadata.causal,
+                    alibi_slopes=self.alibi_slopes,
+                    window_size=self.sliding_window,
+                    block_table=block_table,
+                    softcap=self.logits_soft_cap,
+                    scheduler_metadata=scheduler_metadata,
+                    fa_version=self.vllm_flash_attn_version,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
+                    num_splits=attn_metadata.max_num_splits,
+                    s_aux=self.sinks,
+                )
+                return output
 
         # Cascade attention (rare case).
         cascade_attention(
@@ -615,6 +673,86 @@ def forward(
         )
         return output
 
+    def _forward_with_dcp(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        q_descale: torch.Tensor | None = None,
+        k_descale: torch.Tensor | None = None,
+        v_descale: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        cu_seqlens_q = attn_metadata.query_start_loc
+        max_seqlen_q = attn_metadata.max_query_len
+        block_table = attn_metadata.block_table
+
+        query = query.contiguous()
+        query_across_dcp = get_dcp_group().all_gather(query, dim=1)
+        context_attn_out, context_lse = flash_attn_varlen_func(
+            q=query_across_dcp,
+            k=key_cache,
+            v=value_cache,
+            out=None,
+            cu_seqlens_q=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            seqused_k=attn_metadata.dcp_context_kv_lens,
+            max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
+            softmax_scale=self.scale,
+            causal=False,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            block_table=block_table,
+            softcap=self.logits_soft_cap,
+            return_softmax_lse=True,
+            scheduler_metadata=attn_metadata.scheduler_metadata,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=q_descale,
+            k_descale=k_descale,
+            v_descale=v_descale,
+        )
+        # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
+        context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
+            context_attn_out,
+            context_lse.transpose(0, 1),
+            get_dcp_group(),
+            return_lse=True,
+        )
+        context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()
+
+        query_attn_out, query_lse = flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            out=None,
+            cu_seqlens_q=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            cu_seqlens_k=cu_seqlens_q,
+            max_seqlen_k=max_seqlen_q,
+            softmax_scale=self.scale,
+            causal=attn_metadata.causal,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            softcap=self.logits_soft_cap,
+            return_softmax_lse=True,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=q_descale,
+            k_descale=k_descale,
+            v_descale=v_descale,
+        )
+        assert context_attn_out_cor.shape == query_attn_out.shape
+        assert context_lse_cor.shape == query_lse.shape
+        merge_attn_states(
+            output,
+            context_attn_out_cor,
+            context_lse_cor,
+            query_attn_out,
+            query_lse,
+        )
+
     def _forward_encoder_attention(
         self,
         query: torch.Tensor,
@@ -684,6 +822,7 @@ def use_cascade_attention(
     use_sliding_window: bool,
     use_local_attention: bool,
     num_sms: int,
+    dcp_world_size: int,
 ) -> bool:
     """Decide whether to use cascade attention.
@@ -705,6 +844,9 @@ def use_cascade_attention(
     num_reqs = len(query_lens)
     if num_reqs < 8:
         return False
+    # disable cascade attention for DCP
+    if dcp_world_size > 1:
+        return False
 
     # Heuristics to decide whether using cascade attention is beneficial.
     # 1. When FlashDecoding is not used for normal attention, cascade attention
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index beb267f196fb..cb5855548098 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -345,6 +345,7 @@ def use_cascade_attention(
         use_sliding_window: bool,
         use_local_attention: bool,
         num_sms: int,
+        dcp_world_size: int,
     ) -> bool:
         return False
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 0d99597fa641..f9defc9595f4 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1523,6 +1523,7 @@ def _compute_cascade_attn_prefix_len(
             use_sliding_window=use_sliding_window,
             use_local_attention=use_local_attention,
             num_sms=self.num_sms,
+            dcp_world_size=self.dcp_world_size,
         )
         return common_prefix_len if use_cascade else 0
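
Note on the decode path added above: `_forward_with_dcp` computes two partial attention results for the same queries (one over this DCP rank's slice of the cached context KV, one over the query-local KV) and combines them via `merge_attn_states` using the per-head log-sum-exp (LSE) values returned by FlashAttention. Below is a minimal PyTorch sketch of that combination rule, not the fused Triton/CUDA kernels vLLM actually uses; the shapes ([B, H, D] outputs, [B, H] LSEs) and the helper name `merge_partial_attn` are illustrative assumptions only.

import torch


def merge_partial_attn(
    out_a: torch.Tensor,  # [B, H, D] attention output over KV subset A
    lse_a: torch.Tensor,  # [B, H] log-sum-exp of scores over subset A
    out_b: torch.Tensor,  # [B, H, D] attention output over KV subset B
    lse_b: torch.Tensor,  # [B, H] log-sum-exp of scores over subset B
) -> tuple[torch.Tensor, torch.Tensor]:
    # Combined softmax normalizer over both subsets, then rescale each
    # partial result so their sum equals attention over the union of KV.
    lse = torch.logaddexp(lse_a, lse_b)
    out = (
        torch.exp(lse_a - lse).unsqueeze(-1) * out_a
        + torch.exp(lse_b - lse).unsqueeze(-1) * out_b
    )
    return out, lse

The same identity is what `cp_lse_ag_out_rs` / `correct_attn_out` rely on: each DCP rank's partial output is rescaled by exp(lse_rank - lse_total) after the all-gather of LSEs, and the reduce-scatter then sums the rescaled contributions.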