diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 91fe2a918f..6d9b918e68 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -214,7 +214,6 @@ def _flash_attn_varlen_forward_fake(
     paged_kv = block_table is not None
     batch_size = cu_seqlens_q.numel() - 1
     total_q, num_heads, _ = q.shape
-    out = torch.empty_like(q)
     softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout)
     p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
 
@@ -252,8 +251,7 @@ def _flash_attn_backward(
     alibi_slopes: Optional[torch.Tensor],
     deterministic: bool,
     rng_state: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    # dq, dk, dv are allocated by us so they should already be contiguous
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     (
         dq,
@@ -281,7 +279,7 @@ def _flash_attn_backward(
         None,
         rng_state,
     )
-    return softmax_d
+    return dq.clone(), dk.clone(), dv.clone(), softmax_d
 
 
 @_torch_register_fake_wrapper("flash_attn::_flash_attn_backward")
@@ -304,18 +302,17 @@ def _flash_attn_backward_fake(
     alibi_slopes: Optional[torch.Tensor],
     deterministic: bool,
     rng_state: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
-    if dq is None:
-        dq = torch.empty_like(q)
-    if dk is None:
-        dk = torch.empty_like(k)
-    if dv is None:
-        dv = torch.empty_like(v)
     batch_size, seqlen_q, num_heads, _ = q.shape
-    softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32)
-
-    return softmax_d
+    if torch.cuda.is_available() and torch.version.hip:
+        softmax_d = torch.empty((batch_size, num_heads, seqlen_q), device=q.device, dtype=torch.float32)
+    else:
+        softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32)
+    # dq, dk, dv are already allocated in the fwd pass;
+    # we pass them through here to match the cpp signature and help torch.compile infer shapes during tracing.
+    # Without this, torch.compile struggles to infer the shape of softmax_d.
+    return dq, dk, dv, softmax_d
 
 
 if torch.__version__ >= "2.4.0":
@@ -348,8 +345,7 @@ def _flash_attn_varlen_backward(
     alibi_slopes: Optional[torch.Tensor],
     deterministic: bool,
     rng_state: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    # dq, dk, dv are allocated by us so they should already be contiguous
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     (
         dq,
@@ -382,9 +378,8 @@ def _flash_attn_varlen_backward(
         None,
         rng_state,
     )
-    # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
-    #     breakpoint()
-    return softmax_d
+    # return clones, otherwise torch.compile will complain about mutated tensors being returned
+    return dq.clone(), dk.clone(), dv.clone(), softmax_d
 
 
 @_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_backward")
@@ -411,20 +406,22 @@ def _flash_attn_varlen_backward_fake(
     alibi_slopes: Optional[torch.Tensor],
     deterministic: bool,
     rng_state: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     batch_size = cu_seqlens_q.numel() - 1
     total_q, num_heads, _ = q.shape
 
-    if dq is None:
-        dq = torch.empty_like(q)
-    if dk is None:
-        dk = torch.empty_like(k)
-    if dv is None:
-        dv = torch.empty_like(v)
-    softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32)
-
-    return softmax_d
+
+    # The CUDA kernel appears to round up max_seqlen_q to a multiple of 128
+    if torch.cuda.is_available() and torch.version.hip:
+        softmax_d = torch.empty((batch_size, num_heads, max_seqlen_q), device=q.device, dtype=torch.float32)
+    else:
+        softmax_d = torch.empty((batch_size, num_heads, round_multiple(max_seqlen_q, 128)), device=q.device, dtype=torch.float32)
+
+    # dq, dk, dv are already allocated in the fwd pass;
+    # we pass them through here to match the cpp signature and help torch.compile infer shapes during tracing.
+    # Without this, torch.compile struggles to infer the shape of softmax_d.
+    return dq, dk, dv, softmax_d
 
 
 if torch.__version__ >= "2.4.0":
diff --git a/tests/test_flash_attn_ck.py b/tests/test_flash_attn_ck.py
index 84963501b6..201797436f 100644
--- a/tests/test_flash_attn_ck.py
+++ b/tests/test_flash_attn_ck.py
@@ -302,8 +302,9 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi,
     ],
 )
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
+@pytest.mark.parametrize("compiled", [False, True])
 def test_flash_attn_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, compiled
 ):
     device = "cuda"
     # set seed
@@ -343,7 +344,8 @@ def test_flash_attn_output(
             return_attn_probs=True,
         )
     else:
-        out, lse, S_dmask = flash_attn_func(
+        flash_func = torch.compile(flash_attn_func) if compiled else flash_attn_func
+        out, lse, S_dmask = flash_func(
             q,
             k,
             v,
@@ -519,8 +521,9 @@ def test_flash_attn_output(
     ],
 )
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
+@pytest.mark.parametrize("compiled", [False, True])
 def test_flash_attn_varlen_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, compiled
 ):
     device = "cuda"
     # set seed
@@ -598,7 +601,8 @@ def test_flash_attn_varlen_output(
         dq_pad_fn,
         dk_pad_fn,
     ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
-        out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
+        flash_varlen_func = torch.compile(flash_attn_varlen_func) if compiled else flash_attn_varlen_func
+        out_unpad, sm_lse, S_dmask = flash_varlen_func(
             q_unpad,
             k_unpad,
             v_unpad,
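
For context on the fake-op comments above, here is a small self-contained sketch, not taken from this patch, of why a fake (meta) implementation has to return tensors with the right shapes: while tracing, torch.compile never runs the real kernel and only sees what the fake returns. The demo::toy_bwd op, its arguments, and its shapes are made-up illustrations; the registration pattern mirrors the torch.library.custom_op / register_fake wrappers that flash_attn uses on torch >= 2.4.

import torch
from typing import Tuple

# Hypothetical toy op; only the registration pattern matters here.
@torch.library.custom_op("demo::toy_bwd", mutates_args=())
def toy_bwd(dout: torch.Tensor, q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    dq = dout * q                                # stand-in for the real backward kernel
    softmax_d = dout.sum(-1).to(torch.float32)
    return dq, softmax_d

@toy_bwd.register_fake
def _(dout, q):
    # Shape-only implementation used during tracing: allocate empty tensors with
    # the shapes/dtypes the real kernel would produce, without any computation.
    dq = torch.empty_like(q)
    softmax_d = torch.empty(q.shape[:-1], dtype=torch.float32, device=q.device)
    return dq, softmax_d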
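
And a minimal usage sketch of what the new compiled parametrization in the tests exercises end to end; the shapes, dtype, and causal flag below are illustrative assumptions, not values from the tests.

import torch
from flash_attn import flash_attn_func

# Illustrative shapes: (batch, seqlen, nheads, headdim), fp16 on GPU.
q, k, v = [
    torch.randn(2, 128, 4, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
]

flash_func = torch.compile(flash_attn_func)  # same wrapping the updated tests apply
out = flash_func(q, k, v, causal=True)       # forward traces through the registered custom ops
out.sum().backward()                         # backward hits the op that now returns (dq, dk, dv, softmax_d)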