diff --git a/hopper/block.h b/hopper/block.h
index 3da119cae9..3de4cab26e 100644
--- a/hopper/block.h
+++ b/hopper/block.h
@@ -38,8 +38,13 @@ struct BlockMN {
             // TODO: check off-by-1 error
             if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }
             // If local, blocking (m_idx_max - m_idx_min + window_size_right + window_size_left)
+            // when cp is not enabled, tot_seqlen_k is equal to seqlen_k, and cp_world_size is 1.
+            // cp_world_size is guaranteed to be greater than 0
             n_block_max = std::min(n_block_max,
-                                   cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right, kBlockN));
+                                   cute::ceil_div(
+                                       cute::ceil_div(m_idx_max + seqlen_info.tot_seqlen_k - seqlen_q + window_size_right - seqlen_info.cp_rank,
+                                                      seqlen_info.cp_world_size),
+                                       kBlockN));
         }
         // Now, only adjust n_block_min if split
         int n_block_min = 0;
diff --git a/hopper/flash.h b/hopper/flash.h
index 28997613dc..dcc912249a 100644
--- a/hopper/flash.h
+++ b/hopper/flash.h
@@ -161,6 +161,11 @@ struct Flash_fwd_params : public Qkv_params {

     // The S extra matrix, (num_heads)
     void *__restrict__ s_aux_ptr;
+
+    // CP (Context Parallelism) parameters
+    int cp_world_size;
+    int cp_rank;
+    int *__restrict__ cp_tot_seqused_k;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp
index 0cfebb0146..43c442f43d 100644
--- a/hopper/flash_api.cpp
+++ b/hopper/flash_api.cpp
@@ -707,7 +707,10 @@ mha_fwd(at::Tensor &q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         int num_splits,
         std::optional<bool> pack_gqa_,
         int const sm_margin,
-        std::optional<const at::Tensor> &s_aux_  // (h)
+        std::optional<const at::Tensor> &s_aux_,  // (h)
+        int const cp_world_size,  // context parallelism (cp) world size
+        int const cp_rank,  // cp rank
+        std::optional<const at::Tensor> &cp_tot_seqused_k_  // b. total seqused_k in cp world
         ) {

     auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -845,6 +848,12 @@ mha_fwd(at::Tensor &q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
         CHECK_SHAPE(seqused_k, batch_size);
     }
+    if (cp_tot_seqused_k_.has_value()) {
+        auto cp_tot_seqused_k = cp_tot_seqused_k_.value();
+        TORCH_CHECK(cp_tot_seqused_k.dtype() == torch::kInt32, "cp_tot_seqused_k must have dtype int32");
+        CHECK_DEVICE(cp_tot_seqused_k); CHECK_CONTIGUOUS(cp_tot_seqused_k);
+        CHECK_SHAPE(cp_tot_seqused_k, batch_size);
+    }

     if (leftpad_k_.has_value()) {
         auto leftpad_k = leftpad_k_.value();
@@ -1154,6 +1163,14 @@ mha_fwd(at::Tensor &q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         params.s_aux_ptr = nullptr;
     }

+    params.cp_world_size = cp_world_size;
+    params.cp_rank = cp_rank;
+    params.cp_tot_seqused_k = cp_tot_seqused_k_.has_value() ?
+        static_cast<int *>(cp_tot_seqused_k_.value().data_ptr()) : nullptr;
+    TORCH_CHECK(cp_world_size > 0, "cp_world_size must be positive, required by downstream unified code path. Use 1 if CP is not enabled.");
+    TORCH_CHECK(cp_world_size != 1 || cp_rank == 0, "When context parallelism is disabled, cp_rank must be zero");
+    TORCH_CHECK(cp_world_size == 1 || cp_tot_seqused_k_.has_value(), "cp_tot_seqused_k_ must be provided when context parallelism is enabled.");
+
 #ifdef FLASHATTENTION_DISABLE_LOCAL
     TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
 #endif
@@ -1670,4 +1687,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata, "Get scheduler metadata for varlen forward pass");
 }

-#endif
\ No newline at end of file
+#endif
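The nested ceil_div added in hopper/block.h above converts a bound on global key positions into a count of this rank's local keys: under interleaved sharding, local key j on rank cp_rank sits at global position j * cp_world_size + cp_rank. A small standalone sketch with assumed values (not taken from the patch) that checks this arithmetic:

```python
def ceil_div(a, b):
    return -(-a // b)

# Assumed example values
kBlockN, W, r = 128, 4, 1            # block size, cp_world_size, cp_rank
seqlen_q, tot_k, wr = 1024, 1024, 0  # query length, tot_seqlen_k, window_size_right
m_idx_max = 256                      # exclusive upper bound of query rows in this m_block

# Brute force: local keys whose global position j*W + r is visible to some row < m_idx_max
last_visible = m_idx_max - 1 + tot_k - seqlen_q + wr
n_local_visible = sum(1 for j in range(ceil_div(tot_k - r, W)) if j * W + r <= last_visible)
print(ceil_div(n_local_visible, kBlockN))                                     # 1

# Closed form used in block.h (before the std::min clamp against the full n_block_max)
print(ceil_div(ceil_div(m_idx_max + tot_k - seqlen_q + wr - r, W), kBlockN))  # 1
```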
diff --git a/hopper/flash_api_torch_lib.cpp b/hopper/flash_api_torch_lib.cpp
index ad2c515f9d..338c9d408b 100644
--- a/hopper/flash_api_torch_lib.cpp
+++ b/hopper/flash_api_torch_lib.cpp
@@ -52,7 +52,10 @@ mha_fwd(at::Tensor &q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         int num_splits,
         std::optional<bool> pack_gqa_,
         int const sm_margin,
-        std::optional<const at::Tensor> &s_aux_
+        std::optional<const at::Tensor> &s_aux_,
+        int const cp_world_size,
+        int const cp_rank,
+        std::optional<const at::Tensor> &cp_tot_seqused_k
         );

 // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
@@ -120,7 +123,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
             " int num_splits,"
             " bool? pack_gqa,"
             " int sm_margin,"
-            " Tensor? s_aux) -> Tensor[]");
+            " Tensor? s_aux,"
+            " int cp_world_size,"
+            " int cp_rank,"
+            " Tensor? cp_tot_seqused_k) -> Tensor[]");
     ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));

     ops.def("get_scheduler_metadata("
@@ -151,4 +157,4 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
               make_pytorch_shim(&mha_fwd_get_scheduler_metadata));
 }

-REGISTER_EXTENSION(TORCH_EXTENSION_NAME);
\ No newline at end of file
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME);
diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py
index d3150fbb67..6f06f96c83 100644
--- a/hopper/flash_attn_interface.py
+++ b/hopper/flash_attn_interface.py
@@ -49,7 +49,10 @@ def _flash_attn_forward(
         num_splits=1,
         pack_gqa=None,
         sm_margin=0,
-        s_aux=None):
+        s_aux=None,
+        cp_world_size=1,
+        cp_rank=0,
+        cp_tot_seqused_k=None):
     q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
     v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
     cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
@@ -95,7 +98,10 @@ def _flash_attn_forward(
         num_splits,
         pack_gqa,
         sm_margin,
-        s_aux
+        s_aux,
+        cp_world_size,
+        cp_rank,
+        cp_tot_seqused_k,
     )
     return out, softmax_lse, *rest

@@ -260,6 +266,9 @@ def forward(
         deterministic=False,
         sm_margin=0,
         s_aux=None,
+        cp_world_size=1,
+        cp_rank=0,
+        cp_tot_seqused_k=None,
     ):
         if softmax_scale is None:
             softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -285,6 +294,9 @@ def forward(
             pack_gqa=pack_gqa,
             sm_margin=sm_margin,
             s_aux=s_aux,
+            cp_world_size=cp_world_size,
+            cp_rank=cp_rank,
+            cp_tot_seqused_k=cp_tot_seqused_k,
         )
         # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
         ctx.save_for_backward(q, k, v, out, softmax_lse)
@@ -351,6 +363,9 @@ def forward(
         deterministic=False,
         sm_margin=0,
         s_aux=None,
+        cp_world_size=1,
+        cp_rank=0,
+        cp_tot_seqused_k=None,
     ):
         if softmax_scale is None:
             softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -380,6 +395,9 @@ def forward(
             pack_gqa=pack_gqa,
             sm_margin=sm_margin,
             s_aux=s_aux,
+            cp_world_size=cp_world_size,
+            cp_rank=cp_rank,
+            cp_tot_seqused_k=cp_tot_seqused_k,
         )
         # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
         ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
@@ -497,6 +515,9 @@ def flash_attn_func(
     deterministic=False,
     sm_margin=0,
     s_aux=None,
+    cp_world_size=1,
+    cp_rank=0,
+    cp_tot_seqused_k=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -558,6 +579,9 @@ def flash_attn_func(
         deterministic,
         sm_margin,
         s_aux,
+        cp_world_size,
+        cp_rank,
+        cp_tot_seqused_k,
     )


@@ -582,6 +606,9 @@ def flash_attn_varlen_func(
     deterministic=False,
     sm_margin=0,
     s_aux=None,
+    cp_world_size=1,
+    cp_rank=0,
+    cp_tot_seqused_k=None,
 ):
     return FlashAttnVarlenFunc.apply(
         q,
@@ -604,6 +631,9 @@ def flash_attn_varlen_func(
         deterministic,
         sm_margin,
         s_aux,
+        cp_world_size,
+        cp_rank,
+        cp_tot_seqused_k,
     )


@@ -642,6 +672,9 @@ def flash_attn_with_kvcache(
     sm_margin=0,  # Can be tuned if some SMs are used for communication
     return_softmax_lse=False,
     s_aux=None,
+    cp_world_size=1,
+    cp_rank=0,
+    cp_tot_seqused_k=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -769,6 +802,9 @@ def flash_attn_with_kvcache(
         pack_gqa=pack_gqa,
         sm_margin=sm_margin,
         s_aux=s_aux,
+        cp_world_size=cp_world_size,
+        cp_rank=cp_rank,
+        cp_tot_seqused_k=cp_tot_seqused_k,
     )
     # return (out, softmax_lse) if return_softmax_lse else out
     return (out, softmax_lse, *rest) if return_softmax_lse else out
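For orientation, a minimal usage sketch of the extended Python interface. It assumes the K/V passed in are this rank's interleaved shard (global token j * cp_world_size + cp_rank lives at local index j) and that the import path matches hopper/test_flash_attn.py; shapes, dtypes, and sizes are illustrative only:

```python
import torch
from flash_attn_interface import flash_attn_func

batch, seqlen_q, local_seqlen_k, nheads, d = 2, 128, 256, 8, 128
cp_world_size, cp_rank = 2, 0  # this process holds the even global K/V positions

q = torch.randn(batch, seqlen_q, nheads, d, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, local_seqlen_k, nheads, d, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, local_seqlen_k, nheads, d, device="cuda", dtype=torch.bfloat16)
# Global K length per batch element (all ranks together), int32 on device.
cp_tot_seqused_k = torch.full((batch,), local_seqlen_k * cp_world_size,
                              device="cuda", dtype=torch.int32)

out, lse, *rest = flash_attn_func(
    q, k, v,
    causal=True,
    cp_world_size=cp_world_size,
    cp_rank=cp_rank,
    cp_tot_seqused_k=cp_tot_seqused_k,
)
```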
diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h
index 242da9bf8a..fd44f33889 100644
--- a/hopper/flash_fwd_kernel_sm90.h
+++ b/hopper/flash_fwd_kernel_sm90.h
@@ -347,7 +347,10 @@ class FlashAttnFwdSm90 {
                 get<0>(params.mainloop.shape_K_new),
                 params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
                 params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
-                params.mainloop.seqlens_rotary
+                params.mainloop.seqlens_rotary,
+                params.mainloop.cp_world_size,
+                params.mainloop.cp_rank,
+                params.mainloop.cp_tot_seqused_k
             };
             if constexpr (AppendKV) {
                 bool tile_new_valid = mainloop.load_kv_new(
@@ -396,7 +399,9 @@ class FlashAttnFwdSm90 {
                 get<0>(params.mainloop.shape_K_new),
                 params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
                 params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
-                params.mainloop.seqlens_rotary
+                params.mainloop.seqlens_rotary, params.mainloop.cp_world_size,
+                params.mainloop.cp_rank,
+                params.mainloop.cp_tot_seqused_k
             };
             if constexpr (AppendKV) {
                 bool tile_new_valid = mainloop.store_kv_new(
diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h
index 616380b3d2..73a128bc98 100644
--- a/hopper/flash_fwd_launch_template.h
+++ b/hopper/flash_fwd_launch_template.h
@@ -129,7 +129,8 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
         params.seqused_q, params.seqused_k,
         params.leftpad_k, params.seqlens_rotary,
-        static_cast(params.s_aux_ptr)
+        static_cast(params.s_aux_ptr),
+        params.cp_world_size, params.cp_rank, params.cp_tot_seqused_k
     };
     typename CollectiveEpilogue::Arguments epilogue_args {
         static_cast(params.o_ptr),
@@ -156,6 +157,8 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q,
         // params.num_m_blocks_ptr,
         params.num_splits_dynamic_ptr,
+        params.cp_world_size,
+        params.cp_rank,
     };

     if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) {
diff --git a/hopper/mainloop_fwd_sm80.hpp b/hopper/mainloop_fwd_sm80.hpp
index 7dcae77109..a7c42fe0ff 100644
--- a/hopper/mainloop_fwd_sm80.hpp
+++ b/hopper/mainloop_fwd_sm80.hpp
@@ -215,6 +215,9 @@ struct CollectiveMainloopFwdSm80 {
         int const* const leftpad_k = nullptr;
         int const* const seqlens_rotary = nullptr;
         ElementSAux const* const ptr_S_aux = nullptr;
+        int cp_world_size;
+        int cp_rank;
+        int const* const cp_tot_seqused_k = nullptr;
     };

     // Device side kernel params
diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
index 50a2d7fb80..6e0d8b768b 100644
--- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
+++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
@@ -412,6 +412,10 @@ struct CollectiveMainloopFwdSm90 {
         int const* const leftpad_k = nullptr;
         int const* const seqlens_rotary = nullptr;
         ElementSAux const* const ptr_S_aux = nullptr;
+        // Context parallelism (CP) parameters
+        int const cp_world_size = 1;
+        int const cp_rank = 0;
+        int const* const cp_tot_seqused_k = nullptr;
     };

     // Device side kernel params
@@ -469,6 +473,9 @@ struct CollectiveMainloopFwdSm90 {
         int const* const leftpad_k = nullptr;
         int const* const seqlens_rotary = nullptr;
         ElementSAux const* const ptr_S_aux = nullptr;
+        int cp_world_size = 1;
+        int cp_rank = 0;
+        int const* const cp_tot_seqused_k = nullptr;
     };

     static Params
@@ -584,7 +591,8 @@ struct CollectiveMainloopFwdSm90 {
                 args.kv_batch_idx,
                 args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new,
                 args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary,
-                args.ptr_S_aux};
+                args.ptr_S_aux,
+                args.cp_world_size, args.cp_rank, args.cp_tot_seqused_k};
     }

     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -1093,7 +1101,8 @@ struct CollectiveMainloopFwdSm90 {
         // But we subtract n_offset for consistency in mask calculations
         flash::Mask mask(
             thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 - n_offset /*sink_token_length*/,
-            params.qhead_per_khead_divmod
+            params.qhead_per_khead_divmod,
+            params.cp_world_size, params.cp_rank, seqlen_info.tot_seqlen_k
         );

         float softcap_val = params.softcap_val;
@@ -1275,8 +1284,13 @@ struct CollectiveMainloopFwdSm90 {
             auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); };
             int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM);
             // If local, blocking (window_size_right + window_size_left)
+            // when cp is not enabled, tot_seqlen_k is equal to seqlen_k, and cp_world_size is 1.
+            // cp_world_size is guaranteed to be greater than 0
             int const n_block_min_causal_local_mask =
-                std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN);
+                std::max(n_block_min,
+                         (m_idx_min + seqlen_info.tot_seqlen_k - seqlen_q + params.window_size_right) /
+                             seqlen_info.cp_world_size /
+                             kBlockN);
             #pragma unroll 1
             for (; n_block >= n_block_min_causal_local_mask; --n_block) {
                 fwd_step(n_block, mask_fn, cute::true_type{} /*check_inf*/);
@@ -1285,10 +1299,15 @@ struct CollectiveMainloopFwdSm90 {
             int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1;
             // If local, blocking (m_idx_max - m_idx_min)
+            // when cp is not enabled, tot_seqlen_k is equal to seqlen_k, and cp_world_size is 1.
+            // cp_world_size is guaranteed to be greater than 0
             int const n_block_min_before_local_mask = !Is_local
                 ? n_block_min
                 : std::max(n_block_min,
-                           cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN));
+                           cute::ceil_div(
+                               cute::ceil_div(m_idx_max + seqlen_info.tot_seqlen_k - seqlen_q - params.window_size_left - seqlen_info.cp_rank,
+                                              seqlen_info.cp_world_size),
+                               kBlockN));
             auto no_mask_fn = [](auto& tSrS, int n_block) { };
             #pragma unroll 1
             for (; n_block >= n_block_min_before_local_mask; --n_block) {
diff --git a/hopper/mask.h b/hopper/mask.h
index 02d046268c..c3cba01250 100644
--- a/hopper/mask.h
+++ b/hopper/mask.h
@@ -23,11 +23,13 @@ struct Mask {
     int const seqlen_q, seqlen_k;
     int const window_size_left, window_size_right, sink_token_length;
     cutlass::FastDivmod const qhead_per_khead_divmod;
+    int const cp_world_size, cp_rank, tot_seqlen_k;

     CUTLASS_DEVICE
     Mask(const int thread_idx, const int seqlen_q, const int seqlen_k,
          const int window_size_left, const int window_size_right, const int sink_token_length,
-         cutlass::FastDivmod const &qhead_per_khead_divmod)
+         cutlass::FastDivmod const &qhead_per_khead_divmod,
+         const int cp_world_size = 1, const int cp_rank = 0, const int tot_seqlen_k = 0)
         : thread_idx(thread_idx)
         , seqlen_q(seqlen_q)
         , seqlen_k(seqlen_k)
@@ -35,6 +37,9 @@ struct Mask {
         , window_size_right(window_size_right)
         , sink_token_length(sink_token_length)
         , qhead_per_khead_divmod(qhead_per_khead_divmod)
+        , cp_world_size(cp_world_size)
+        , cp_rank(cp_rank)
+        , tot_seqlen_k(tot_seqlen_k)
     {
     };

@@ -94,7 +99,19 @@ struct Mask {
                     : __viaddmin_s32(row_idx, causal_row_offset, seqlenk_col_limit);
                 #pragma unroll
                 for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
-                    if (int(get<Col>(t0ScS_rowcol(_0{}, n))) >= col_limit_right) { tSrS_rowcol(m, n) = -INFINITY; }
+                    int col_idx = int(get<Col>(t0ScS_rowcol(_0{}, n)));
+                    if (cp_world_size > 1) {
+                        int local_k_idx = int(get<Col>(t0ScS_rowcol(_0{}, n))) + get<Col>(tScS_rowcol(_0{}, _0{})) + n_block * kBlockN;
+                        int abs_k_idx = local_k_idx * cp_world_size + cp_rank;
+                        int k_limit = row_idx + tot_seqlen_k - seqlen_q;
+                        if (abs_k_idx > k_limit || (Seqlenk_mask && abs_k_idx >= tot_seqlen_k)) {
+                            tSrS_rowcol(m, n) = -INFINITY;
+                        }
+                    } else {
+                        if (col_idx >= col_limit_right) {
+                            tSrS_rowcol(m, n) = -INFINITY;
+                        }
+                    }
                 }
             }
         } else {
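The cp_world_size > 1 branch added in mask.h above applies the same rule element-wise; a minimal host-side sketch of that rule (the helper, shapes, and example values are illustrative and independent of the kernel's register layout):

```python
import torch

def cp_causal_mask(seqlen_q, local_seqlen_k, tot_seqlen_k, cp_world_size, cp_rank, seqlenk_mask=True):
    """True = masked out. Local key j maps to global position j * cp_world_size + cp_rank;
    row m may attend to global positions <= m + tot_seqlen_k - seqlen_q (causal, aligned at the end)."""
    row = torch.arange(seqlen_q).unsqueeze(1)
    abs_k = torch.arange(local_seqlen_k) * cp_world_size + cp_rank
    mask = abs_k > row + tot_seqlen_k - seqlen_q
    if seqlenk_mask:
        mask |= abs_k >= tot_seqlen_k
    return mask

# 2 ranks; rank 1 holds global positions 1, 3, 5, 7
print(cp_causal_mask(seqlen_q=4, local_seqlen_k=4, tot_seqlen_k=8, cp_world_size=2, cp_rank=1))
```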
diff --git a/hopper/seqlen.h b/hopper/seqlen.h
index 5547238b34..c632216256 100644
--- a/hopper/seqlen.h
+++ b/hopper/seqlen.h
@@ -33,12 +33,15 @@ struct SeqlenInfoQK {

     int const offset_q, offset_k, offset_q_padded;
     int const seqlen_q, seqlen_k;
+    int const cp_world_size;
+    int const tot_seqlen_k;

     CUTLASS_DEVICE
     SeqlenInfoQK(int const bidb, int const seqlen_q_static, int const seqlen_k_static,
                  int const* const cu_seqlens_q, int const* const cu_seqlens_k,
-                 int const* const seqused_q, int const* const seqused_k
-                 )
+                 int const* const seqused_q, int const* const seqused_k,
+                 int const cp_world_size=1,
+                 int const* const cp_tot_seqused_k=nullptr)
         : offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb])
         , offset_k(!Varlen || cu_seqlens_k == nullptr ? 0 : cu_seqlens_k[bidb])
         // If varlen, the layout for dPSum, LSE_log2, and dQaccum is that we pad each sequence in the batch
@@ -52,6 +55,10 @@ struct SeqlenInfoQK {
         , seqlen_k(!Varlen
                    ? seqlen_k_static
                    : (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ? cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static)))
+        , cp_world_size(cp_world_size)
+        , tot_seqlen_k(cp_tot_seqused_k == nullptr and cp_world_size <= 1
+                       ? seqlen_k
+                       : cp_tot_seqused_k[bidb])
     {
     }

@@ -65,12 +72,16 @@ struct SeqlenInfoQKNewK {

     int const leftpad_k;
     int const offset_q, offset_k, offset_k_new;
     int const seqlen_q, seqlen_k_og, seqlen_k_new, seqlen_k, seqlen_rotary;
+    int const cp_world_size;
+    int const cp_rank;
+    int const tot_seqlen_k;

     CUTLASS_DEVICE
     SeqlenInfoQKNewK(int const bidb, int const seqlen_q_static, int const seqlen_k_static, int const shape_K_new_0,
                      int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new,
                      int const* const seqused_q, int const* const seqused_k, int const* const ptr_leftpad_k,
-                     int const* const seqlens_rotary
+                     int const* const seqlens_rotary, int const cp_world_size=1, int const cp_rank=0,
+                     int const* const cp_tot_seqused_k=nullptr
                      )
         : leftpad_k(ptr_leftpad_k ? ptr_leftpad_k[bidb] : 0)
         , offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb])
@@ -87,6 +98,11 @@ struct SeqlenInfoQKNewK {
                        : (cu_seqlens_k_new ? cu_seqlens_k_new[bidb + 1] - cu_seqlens_k_new[bidb] : shape_K_new_0))
         , seqlen_k(!AppendKV ? seqlen_k_og : seqlen_k_og + seqlen_k_new)
         , seqlen_rotary(!AppendKV || !seqlens_rotary ? seqlen_k_og + leftpad_k : seqlens_rotary[bidb])
+        , cp_world_size(cp_world_size)
+        , cp_rank(cp_rank)
+        , tot_seqlen_k(cp_tot_seqused_k == nullptr and cp_world_size <= 1
+                       ? seqlen_k
+                       : cp_tot_seqused_k[bidb])
     {
     }
diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py
index 9b390fa431..0b1f58ce70 100644
--- a/hopper/test_flash_attn.py
+++ b/hopper/test_flash_attn.py
@@ -122,8 +122,19 @@
     ],
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
+@pytest.mark.parametrize(
+    "cp_world_size,cp_rank,cp_tot_seqlen_k_offset",
+    [
+        (8,0,1),
+        (8,7,0),
+        (4,3,2),
+        (2,0,0),
+        (1,0,0),  # 1 means disabling cp
+    ],
+)
 def test_flash_attn_output(
-    seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv_, mha_type, dtype, test_sink
+    seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv_, mha_type, dtype, test_sink,
+    cp_world_size, cp_rank, cp_tot_seqlen_k_offset
 ):
     if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn):
         pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn")
@@ -131,6 +142,8 @@ def test_flash_attn_output(
         pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)")
     if test_sink and has_qv_:
         pytest.skip("Sink disabled for Qv")
+    if cp_world_size > 1 and local:
+        pytest.skip("context parallelism is not supported with local attention yet")
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
@@ -151,6 +164,8 @@ def test_flash_attn_output(
     s_aux = torch.randn(nheads, device=device, dtype=torch.bfloat16) * 4 if test_sink else None
     # s_aux = torch.ones(nheads, device=device, dtype=torch.bfloat16) * 4 if test_sink else None
     # print("s_aux ", s_aux)
+    cp_tot_seqlen_k = seqlen_k * cp_world_size + cp_tot_seqlen_k_offset
+    cp_tot_seqlen_k = torch.full((batch_size,), cp_tot_seqlen_k, device=device, dtype=torch.int32)
     if test_sink:
         dv_vals = [d]
     for dv in dv_vals:
@@ -168,7 +183,7 @@ def test_flash_attn_output(
         else:
             qv_ref = None
         # Put window_size after QKV randn so that window_size changes from test to test
-        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+        window_size = (-1, -1) if not local else torch.randint(0, cp_tot_seqlen_k[0], (2,))
         # window_size = (-1, -1) if not local else (16, 0)
         if dtype == torch.float8_e4m3fn:
             q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)]
@@ -190,6 +205,9 @@ def test_flash_attn_output(
                 window_size=window_size,
                 softcap=softcap,
                 s_aux=s_aux,
+                cp_world_size=cp_world_size,
+                cp_rank=cp_rank,
+                cp_tot_seqlen_k=cp_tot_seqlen_k,
             )
             out_pt, attn_pt = attention_ref(
                 q_ref,
@@ -206,6 +224,9 @@ def test_flash_attn_output(
                 reorder_ops=True,
                 intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
                 s_aux=s_aux,
+                cp_world_size=cp_world_size,
+                cp_rank=cp_rank,
+                cp_tot_seqlen_k=cp_tot_seqlen_k,
            )

             # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float()
@@ -238,6 +259,9 @@ def test_flash_attn_output(
                 pack_gqa=pack_gqa,
                 num_splits=num_splits,
                 s_aux=s_aux,
+                cp_world_size=cp_world_size,
+                cp_rank=cp_rank,
+                cp_tot_seqused_k=cp_tot_seqlen_k,
             )
             print("Pack GQA =", pack_gqa)
             print("Num splits =", num_splits)
diff --git a/hopper/test_util.py b/hopper/test_util.py
index 6709b79248..a24faf96d3 100644
--- a/hopper/test_util.py
+++ b/hopper/test_util.py
@@ -163,7 +163,24 @@ def construct_local_mask(
     key_padding_mask=None,
     key_leftpad=None,
     device=None,
+    cp_world_size=1,
+    cp_rank=0,
+    cp_tot_seqlen_k=None,
 ):
+    if cp_world_size > 1:
+        return construct_cp_mask(
+            seqlen_q,
+            seqlen_k,
+            cp_world_size=cp_world_size,
+            cp_rank=cp_rank,
+            cp_tot_seqlen_k=cp_tot_seqlen_k,
+            window_size=window_size,
+            sink_token_length=sink_token_length,
+            query_padding_mask=query_padding_mask,
+            key_padding_mask=key_padding_mask,
+            key_leftpad=key_leftpad,
+            device=device,
+        )
     row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
     col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
     if key_leftpad is not None:
@@ -189,6 +206,106 @@ def construct_local_mask(
         torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length),
     )

+def construct_cp_mask(
+    seqlen_q,
+    seqlen_k,
+    cp_world_size=1,
+    cp_rank=0,
+    cp_tot_seqlen_k=None,
+    window_size=(-1, -1),  # -1 means infinite window size
+    sink_token_length=0,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    key_leftpad=None,
+    device=None,
+):
+    """
+    Construct attention mask for context parallelism (DCP).
+
+    This function creates a mask that handles both local windowing and context parallelism.
+    For DCP, each rank only sees a subset of KV tokens (interleaved), and the mask
+    must account for the global positions when applying causal or windowing constraints.
+
+    Args:
+        seqlen_q: Length of query sequence
+        seqlen_k: Length of key sequence (local to this rank)
+        cp_world_size: Number of context parallel ranks
+        cp_rank: Current rank ID (0 to cp_world_size-1)
+        cp_tot_seqlen_k: Total lengths of key sequence in cp world
+        window_size: (left_window, right_window), -1 = infinite
+        sink_token_length: Number of "sink" tokens that can always be attended to
+        query_padding_mask: Which query positions are valid
+        key_padding_mask: Which key positions are valid
+        key_leftpad: Left padding for keys (per batch)
+        device: Device to place tensors on
+
+    Returns:
+        mask: Boolean tensor of shape [seqlen_q, seqlen_k] where True = masked out
+    """
+    # Create position indices
+    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")  # [seqlen_q, 1]
+    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)  # [seqlen_k]
+
+    # Handle left padding if present
+    if key_leftpad is not None:
+        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
+
+    # Calculate effective sequence lengths
+    sk = (
+        cp_tot_seqlen_k[0]
+        if key_padding_mask is None
+        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") * cp_world_size
+    )
+    sq = (
+        torch.tensor(seqlen_q, device=device, dtype=torch.long)  # Global seqlen_k for DCP
+        if query_padding_mask is None
+        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+
+    if cp_world_size > 1:
+        # DCP masking logic
+        # Convert local K indices to global (absolute) K positions
+        # local_k_idx * cp_world_size + cp_rank gives the global position
+        abs_k_idx = col_idx * cp_world_size + cp_rank  # [seqlen_k] -> global positions
+
+        # Query global positions: row_idx + seqlen_k_global - seqlen_q
+        # This handles the case where query and key sequences might have different lengths
+        abs_q_idx = row_idx + sk - sq  # [seqlen_q, 1] -> global query positions
+
+        if window_size[0] < 0:
+            # Infinite left window - essentially causal masking with right window
+            mask = abs_k_idx > abs_q_idx + window_size[1]
+        else:
+            # Finite window - sliding window attention
+            # Right boundary: abs_k_idx > abs_q_idx + window_size[1]
+            right_mask = abs_k_idx > torch.minimum(abs_q_idx + window_size[1], sk)
+
+            # Left boundary: abs_k_idx < abs_q_idx - window_size[0], but exclude sink tokens
+            left_mask = torch.logical_and(
+                abs_k_idx < abs_q_idx - window_size[0],
+                abs_k_idx >= sink_token_length
+            )
+
+            mask = torch.logical_or(right_mask, left_mask)
+
+    else:
+        # Non-DCP case: fall back to original construct_local_mask logic
+        if window_size[0] < 0:
+            mask = col_idx > row_idx + sk - sq + window_size[1]
+        else:
+            sk_local = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk // cp_world_size
+            mask = torch.logical_or(
+                col_idx > torch.minimum(row_idx + sk_local - sq + window_size[1], sk_local),
+                torch.logical_and(
+                    col_idx < row_idx + sk_local - sq - window_size[0],
+                    col_idx >= sink_token_length
+                ),
+            )
+
+    return mask
+

 def attention_ref(
     q,
@@ -209,7 +326,10 @@ def attention_ref(
     upcast=True,
     reorder_ops=False,
     intermediate_dtype=None,
-    s_aux=None
+    s_aux=None,
+    cp_world_size=1,
+    cp_rank=0,
+    cp_tot_seqlen_k=None,
 ):
     """
     Arguments:
@@ -229,6 +349,9 @@ def attention_ref(
             without changing the math. This is to estimate the numerical error from operation reordering.
         s_aux: (nheads)
+        cp_world_size: Number of context parallel ranks
+        cp_rank: Current rank ID (0 to cp_world_size-1)
+        cp_tot_seqlen_k: (batch_size) total seqlen of k/v in cp world
     Output:
         output: (batch_size, seqlen_q, nheads, head_dim_v)
         attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
@@ -276,6 +399,9 @@ def attention_ref(
             key_padding_mask,
             key_leftpad=key_leftpad,
             device=q.device,
+            cp_world_size=cp_world_size,
+            cp_rank=cp_rank,
+            cp_tot_seqlen_k=cp_tot_seqlen_k,
         )
         scores.masked_fill_(local_mask, float("-inf"))
     if attn_bias is not None:
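One property worth keeping in mind when extending these reference helpers: stitching the per-rank CP masks back together column-wise should reproduce the ordinary causal mask over the full (un-sharded) K sequence. A sketch of such a check, assuming it runs alongside the helpers above (e.g., inside hopper/test_util.py) and that construct_local_mask handles window_size=(-1, 0) as plain causal masking:

```python
import torch
from einops import rearrange

seqlen_q, local_k, world = 8, 4, 2
tot_k = local_k * world
cp_tot = torch.tensor([tot_k])

full = construct_local_mask(seqlen_q, tot_k, window_size=(-1, 0), device="cpu")
per_rank = [
    construct_cp_mask(seqlen_q, local_k, cp_world_size=world, cp_rank=r,
                      cp_tot_seqlen_k=cp_tot, window_size=(-1, 0), device="cpu")
    for r in range(world)
]
stitched = rearrange(torch.stack(per_rank, dim=-1), "q k w -> q (k w)")  # column j*world + r
assert torch.equal(stitched, full)
```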
diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp
index 53651d5c84..c147725dd8 100644
--- a/hopper/tile_scheduler.hpp
+++ b/hopper/tile_scheduler.hpp
@@ -26,6 +26,9 @@ struct TileSchedulerArguments {
     int const* const seqused = nullptr;
     // int const* const num_m_blocks_ptr = nullptr;
     int const* const num_splits_dynamic_ptr = nullptr;
+    // CP (Context Parallelism) parameters
+    int const cp_world_size = 1;
+    int const cp_rank = 0;
 };

 ///////////////////////////////////////////////////////////////////////////////
@@ -46,6 +49,8 @@ class SingleTileScheduler {
         int const* const cu_seqlens;
         int const* const seqused;
         int const* const num_splits_dynamic_ptr = nullptr;
+        int const cp_world_size = 1;
+        int const cp_rank = 0;
     };

     static Params
@@ -56,7 +61,8 @@ class SingleTileScheduler {
                 args.qhead_per_khead, args.seqlen,
                 cutlass::FastDivmod(!Split ? 1 : args.num_splits),
                 !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused,
-                args.num_splits_dynamic_ptr};
+                args.num_splits_dynamic_ptr,
+                args.cp_world_size, args.cp_rank};
     }

     static dim3
diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py
index 06de7fd17b..27ef088cca 100644
--- a/vllm_flash_attn/flash_attn_interface.py
+++ b/vllm_flash_attn/flash_attn_interface.py
@@ -146,6 +146,9 @@ def flash_attn_varlen_func(
     # Version selector
     fa_version: int = DEFAULT_FA_VERSION,
     s_aux=None,
+    cp_world_size=1,
+    cp_rank=0,
+    cp_tot_seqused_k=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -279,7 +282,10 @@ def flash_attn_varlen_func(
             num_splits,
             None, # pack_gqa
             0, # sm_margin
-            s_aux # s_aux
+            s_aux, # s_aux
+            cp_world_size,
+            cp_rank,
+            cp_tot_seqused_k,
         )
     else:
         raise ValueError(f"Unsupported FA version: {fa_version}")
@@ -316,6 +322,9 @@ def flash_attn_with_kvcache(
     # Version selector
     fa_version: int = DEFAULT_FA_VERSION,
     s_aux=None,
+    cp_world_size=1,
+    cp_rank=0,
+    cp_tot_seqused_k=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from