@@ -25,6 +25,7 @@
 
 #ifdef PADDLE_WITH_FLASHATTN
 #include "paddle/phi/backends/dynload/flashattn.h"
+#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
 #endif
 
 DECLARE_bool(cudnn_deterministic);
@@ -55,115 +56,89 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
   ctx.template Alloc<T>(dk);
   ctx.template Alloc<T>(dv);
 
-  cudaStream_t stream = ctx.stream();
-  bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
+  const cudaStream_t stream = ctx.stream();
 
   // q,k,v [total_*, num_heads, head_dim]
 
   auto dims = q.dims();
-  int64_t total_q = dims[0];
-  int64_t num_heads = dims[1];
-  int64_t head_size = dims[2];
-
-  int64_t total_k = k.dims()[0];
-  int64_t batch_size = cu_seqlens_q.numel() - 1;
-
-  int num_splits = 0;  // 0 for an internal heuristic, which is optimal
-  if (FLAGS_cudnn_deterministic) {
-    num_splits = 1;
-  }
-  bool zero_tensors = false;
-
-  const int64_t* seed_offset_data = seed_offset.data<int64_t>();
-  uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
-  uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
-
-  VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
-          << ", num_splits:" << num_splits;
-
-  int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
-  DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
-
-  uint64_t workspace_size;
-
-  // calculate workspace size before execution
-  bool succ = phi::dynload::flash_attn_bwd(
-      q.data(),
-      k.data(),
-      v.data(),
-      dq->data(),
-      dk->data(),
-      dv->data(),
-      nullptr,  // for calculation workspace size
-      dout.data(),
-      cu_seqlens_q.data(),
-      cu_seqlens_k.data(),
-      total_q,
-      total_k,
-      batch_size,
-      num_heads,
-      head_size,
-      max_seqlen_q,
-      max_seqlen_k,
-      dropout,
-      scale,
-      zero_tensors,
-      causal,
-      is_bf16,
-      num_splits,
-      const_cast<float*>(softmax_lse.data<float>()),
-      dsoftmax.data(),
-      nullptr,
-      &workspace_size,
-      stream,
-      seed,
-      offset);
-
-  if (!succ) {
-    PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
-  }
-
-  DenseTensor workspace;
-  if (workspace_size > 0) {
-    workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
-  }
-
-  succ = phi::dynload::flash_attn_bwd(
-      q.data(),
-      k.data(),
-      v.data(),
-      dq->data(),
-      dk->data(),
-      dv->data(),
-      out.data(),
-      dout.data(),
-      cu_seqlens_q.data(),
-      cu_seqlens_k.data(),
-      total_q,
-      total_k,
-      batch_size,
-      num_heads,
+  const int64_t total_q = dims[0];
+  const int batch_size = cu_seqlens_q.numel() - 1;
+  const int num_heads = dims[1];
+  const int head_size_og = dout.dims()[2];
+  const int head_size = dims[2];
+  const int total_k = k.dims()[0];
+  const int num_heads_k = k.dims()[1];
+
+  // TODO(umiswing): add deterministic in fa2.
+  // int num_splits = 0;  // 0 for an internal heuristic, which is optimal
+  // if (FLAGS_cudnn_deterministic) {
+  //   num_splits = 1;
+  // }
+
+  const bool zero_tensors = false;
+
+  // TODO(umiswing): add shape check
+  PADDLE_ENFORCE_EQ(
+      head_size_og,
       head_size,
-      max_seqlen_q,
-      max_seqlen_k,
-      dropout,
-      scale,
-      zero_tensors,
-      causal,
-      is_bf16,
-      num_splits,
-      const_cast<float*>(softmax_lse.data<float>()),
-      dsoftmax.data(),
-      workspace_size > 0 ? workspace.data() : nullptr,
-      &workspace_size,
-      stream,
-      seed,
-      offset);
+      phi::errors::InvalidArgument(
+          "flash_attn_bwd expects head_size_og == head_size."));
+
+  FlashAttnBwdParamsV2 params =
+      FlashAttnBwdParamsV2(ctx,
+                           batch_size,
+                           max_seqlen_q,
+                           max_seqlen_k,
+                           num_heads,
+                           num_heads_k,
+                           head_size,
+                           dropout,
+                           scale,
+                           causal,
+                           q.dtype(),
+                           seed_offset.data<int64_t>());
+
+  VLOG(4) << "FlashAttn bwd seed: " << params.seed
+          << ", offset: " << params.offset;
+
+  const bool succ =
+      phi::dynload::flash_attn_varlen_bwd(dout.data(),
+                                          q.data(),
+                                          k.data(),
+                                          v.data(),
+                                          out.data(),
+                                          params.softmax_d.data(),
+                                          softmax_lse.data(),
+                                          cu_seqlens_q.data<int32_t>(),
+                                          cu_seqlens_k.data<int32_t>(),
+                                          params.rng_state.data(),
+                                          dq->data(),
+                                          dk->data(),
+                                          dv->data(),
+                                          params.dq_accum.data(),
+                                          params.batch_size,
+                                          params.max_seqlen_q,
+                                          params.max_seqlen_k,
+                                          params.seqlen_q_rounded,
+                                          params.seqlen_k_rounded,
+                                          params.num_heads,
+                                          params.num_heads_k,
+                                          params.head_size,
+                                          params.head_size_rounded,
+                                          params.dropout,
+                                          params.scale,
+                                          params.causal,
+                                          params.is_bf16,
+                                          stream,
+                                          params.seed,
+                                          params.offset);
 
   if (!succ) {
     PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
   }
-
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "FlashAttention is unsupported, please set use_flash_attn to false."));
 #endif
 }
 
@@ -185,52 +160,86 @@ void FlashAttnGradKernel(const Context& ctx,
   // q,k,v [batch_size, seq_len, num_heads, head_dim]
 
   auto dims = q.dims();
-  int64_t batch_size = dims[0];
-  int64_t seq_len_q = dims[1];
-  int64_t num_heads = dims[2];
-  int64_t head_size = dims[3];
-
-  int64_t seq_len_k = k.dims()[1];
-
-  int64_t total_q = batch_size * seq_len_q;
-  int64_t total_k = batch_size * seq_len_k;
-
-  float scale = 1.0f / std::sqrt(head_size);
+  const int batch_size = dims[0];
+  const int seqlen_q = dims[1];
+  const int num_heads = dims[2];
+  const int head_size_og = dout.dims()[3];
+  const int head_size = dims[3];
+  const int seqlen_k = k.dims()[1];
+  const int num_heads_k = k.dims()[2];
+
+  // TODO(umiswing): add shape check
+  PADDLE_ENFORCE_EQ(
+      head_size_og,
+      head_size,
+      phi::errors::InvalidArgument(
+          "flash_attn_bwd expects head_size_og == head_size."));
 
   VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
           << "], v[" << v.dims() << "]";
 
-  DenseTensor q_t_s, k_t_s, v_t_s;
-  q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
-  k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
-  v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
-
-  DenseTensor cu_seqlens_q;
-  DenseTensor cu_seqlens_k;
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
-
-  FlashAttnUnpaddedGradKernel<T, Context>(ctx,
-                                          q_t_s,
-                                          k_t_s,
-                                          v_t_s,
-                                          cu_seqlens_q,
-                                          cu_seqlens_k,
-                                          out,
-                                          softmax_lse,
-                                          seed_offset,
-                                          dout,
-                                          seq_len_q,
-                                          seq_len_k,
-                                          scale,
-                                          dropout,
-                                          causal,
-                                          dq,
-                                          dk,
-                                          dv);
+  const float scale = 1.0f / std::sqrt(head_size);
+
+  FlashAttnBwdParamsV2 params =
+      FlashAttnBwdParamsV2(ctx,
+                           batch_size,
+                           seqlen_q,
+                           seqlen_k,
+                           num_heads,
+                           num_heads_k,
+                           head_size,
+                           dropout,
+                           scale,
+                           causal,
+                           q.dtype(),
+                           seed_offset.data<int64_t>());
+
+  ctx.template Alloc<T>(dq);
+  ctx.template Alloc<T>(dk);
+  ctx.template Alloc<T>(dv);
+
+  cudaStream_t stream = ctx.stream();
 
+  VLOG(4) << "FlashAttn bwd seed: " << params.seed
+          << ", offset: " << params.offset;
+
+  const bool succ = phi::dynload::flash_attn_bwd(dout.data(),
+                                                 q.data(),
+                                                 k.data(),
+                                                 v.data(),
+                                                 out.data(),
+                                                 params.softmax_d.data(),
+                                                 softmax_lse.data(),
+                                                 params.rng_state.data(),
+                                                 dq->data(),
+                                                 dk->data(),
+                                                 dv->data(),
+                                                 params.dq_accum.data(),
+                                                 params.batch_size,
+                                                 params.max_seqlen_q,
+                                                 params.max_seqlen_k,
+                                                 params.seqlen_q_rounded,
+                                                 params.seqlen_k_rounded,
+                                                 params.num_heads,
+                                                 params.num_heads_k,
+                                                 params.head_size,
+                                                 params.head_size_rounded,
+                                                 params.dropout,
+                                                 params.scale,
+                                                 params.causal,
+                                                 params.is_bf16,
+                                                 stream,
+                                                 params.seed,
+                                                 params.offset);
+
+  PADDLE_ENFORCE_EQ(
+      succ,
+      true,
+      phi::errors::External(
+          "Error in Flash-Attention-2, detail information is: %s",
+          phi::dynload::flash_attn_error()));
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "FlashAttention is unsupported, please set use_flash_attn to false."));
 #endif
 }
 
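
For reference, cu_seqlens_q and cu_seqlens_k in the varlen path are the cumulative per-sequence lengths, starting at 0 and ending at the total token count (batch_size + 1 entries). That is why the kernel recovers batch_size as cu_seqlens_q.numel() - 1 and why total_q, the leading dimension of the packed [total_q, num_heads, head_dim] tensor, matches the last entry; the removed fixed-length path built the same structure for equal-length sequences via ArangeNullaryKernel. A minimal standalone sketch of that packing convention (plain C++, illustrative only, not Paddle code):

    // Standalone illustration of the cu_seqlens packing convention used by
    // the varlen kernels; the lengths below are made-up example values.
    #include <cstdint>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int main() {
      // Per-sequence token counts for a batch of 3 variable-length sequences.
      std::vector<int32_t> seqlens = {5, 3, 8};

      // cu_seqlens is the running sum with a leading 0, so it holds
      // batch_size + 1 entries: {0, 5, 8, 16}.
      std::vector<int32_t> cu_seqlens(seqlens.size() + 1, 0);
      std::partial_sum(seqlens.begin(), seqlens.end(), cu_seqlens.begin() + 1);

      const int batch_size = static_cast<int>(cu_seqlens.size()) - 1;  // 3
      const int total_q = cu_seqlens.back();  // 16, dims[0] of the packed q

      // Sequence i occupies rows [cu_seqlens[i], cu_seqlens[i + 1]) of the
      // packed [total_q, num_heads, head_dim] tensor.
      for (int i = 0; i < batch_size; ++i) {
        std::cout << "seq " << i << ": rows [" << cu_seqlens[i] << ", "
                  << cu_seqlens[i + 1] << ")\n";
      }
      std::cout << "batch_size=" << batch_size << ", total_q=" << total_q
                << "\n";
      return 0;
    }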