72 changes: 72 additions & 0 deletions sgl-kernel/python/sgl_kernel/flash_attn.py
@@ -204,3 +204,75 @@ def flash_attn_with_kvcache(
)
# return (out, softmax_lse) if return_softmax_lse else out
return (out, softmax_lse, *rest) if return_softmax_lse else out


def flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
seqused_q=None,
seqused_k=None,
softmax_scale=None,
causal=False,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=(-1, -1),
softcap=0.0,
num_splits=1,
pack_gqa=None,
sm_margin=0,
return_softmax_lse=False,
):
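"""Variable-length (packed) FlashAttention forward pass.

Shape notes (following the flash-attn varlen convention, not spelled out
in this diff): q is (total_q, num_heads, head_dim); k and v are
(total_k, num_heads_k, head_dim); cu_seqlens_q and cu_seqlens_k are int32
tensors of shape (batch_size + 1,) holding cumulative sequence lengths.
Returns out, or (out, softmax_lse, *rest) when return_softmax_lse is True.
"""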
if not is_fa3_supported():
raise NotImplementedError(
"flash_attn at sgl-kernel is only supported on sm90 and above"
)

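# Default scale is 1/sqrt(head_dim), counting qv's extra head_dim when present.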
if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
-0.5
)

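# The varlen path passes all kv-cache, paging, and rotary arguments as None.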
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
q,
k,
v,
None, # k_new
None, # v_new
qv, # qv
None, # out
cu_seqlens_q,
cu_seqlens_k,
None, # cu_seqlens_k_new
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
None, # page_table
None, # kv_batch_idx
None, # leftpad_k
None, # rotary_cos
None, # rotary_sin
None, # seqlens_rotary
q_descale,
k_descale,
v_descale,
softmax_scale,
causal,
window_size[0],
window_size[1],
softcap,
is_rotary_interleaved=False,
scheduler_metadata=None,
num_splits=num_splits,
pack_gqa=pack_gqa,
sm_margin=sm_margin,
)

return (out, softmax_lse, *rest) if return_softmax_lse else out
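
Below is a minimal usage sketch for the new entry point. It assumes an sm90+ GPU, fp16 inputs, and the usual flash-attn varlen packing convention (token-packed (total, heads, head_dim) tensors and int32 cumulative sequence lengths); these shapes are assumptions carried over from the flash-attn API, not stated in the diff itself.

import torch

from sgl_kernel.flash_attn import flash_attn_varlen_func

# Two sequences of lengths 3 and 5, packed along the token dimension.
seqlens = [3, 5]
total_tokens, num_heads, head_dim = sum(seqlens), 8, 128

q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Cumulative sequence lengths [0, 3, 8]; the kernel expects int32.
cu_seqlens = torch.cumsum(torch.tensor([0] + seqlens), dim=0).to(device="cuda", dtype=torch.int32)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens),
    max_seqlen_k=max(seqlens),
    causal=True,
)
print(out.shape)  # torch.Size([8, 8, 128])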