|
617 | 617 | inplace : (out_grad -> x_grad) |
618 | 618 |
|
619 | 619 | - backward_op : flash_attn_grad |
620 | | - forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) |
| 620 | + forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) |
621 | 621 | args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false) |
622 | 622 | output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad) |
623 | 623 | infer_meta : |
|
628 | 628 | data_type: q |
629 | 629 |
|
630 | 630 | - backward_op : flash_attn_unpadded_grad |
631 | | - forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) |
| 631 | + forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) |
632 | 632 | args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false) |
633 | 633 | output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad) |
634 | 634 | infer_meta : |
|
0 commit comments