
Commit 0473369

[WIP] Integration flash attention 2 (#55758)
* Work for fa-2 padded fwd. Code to be cleaned.
* Work for fa2 unpadded fwd.
* Work for padded-bwd, dk get small diff on np.random.seed(0)
* Anyway I pass paddle's utest, except return softmax without dropout.
* Clean code.
* Modify interface.
* Clean code and add some check.
* Easy compile for dev.
* Fix ci.
* Fix ci-build.
* Add std c++17 option again.
* Limit max job when compiling fa2.
* Remove const_cast
* Add fwd params, to be cleaned.
* Clean code.
* Add bwd params.
* Clean code.
* Add enforce.
* Use v2.0.4
* Pass RNG state to fa2 capi
* Fix review.
* Add assert
* Skip compile for sm less than 80.
1 parent 785684a commit 0473369

8 files changed: +491 -325 lines changed

cmake/external/flashattn.cmake

Lines changed: 4 additions & 2 deletions

@@ -17,7 +17,7 @@ include(ExternalProject)
 add_definitions(-DPADDLE_WITH_FLASHATTN)
 
 set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
-set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
+set(FLASHATTN_SOURCE_SUBDIR csrc)
 set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
 set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/flashattn)
 set(FLASHATTN_TAG 18106c1ba0ccee81b97ca947397c08a141815a47)
@@ -62,7 +62,7 @@ else()
   set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
   set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
   set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
-  set(FLASHATTN_CXX_FLAGS ${CMAKE_CXX_FLAGS})
+  set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
   set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
   set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
 endif()
@@ -92,6 +92,8 @@ ExternalProject_Add(
     -DBUILD_SHARED=ON
     -DCMAKE_POSITION_INDEPENDENT_CODE=ON
     -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
+    -DCMAKE_JOB_POOL_COMPILE:STRING=compile
+    -DCMAKE_JOB_POOLS:STRING=compile=4
     ${EXTERNAL_OPTIONAL_ARGS}
   CMAKE_CACHE_ARGS
     -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}

cmake/third_party.cmake

Lines changed: 9 additions & 4 deletions

@@ -548,10 +548,15 @@ if(WITH_GPU
     list(APPEND third_party_deps extern_cutlass)
     set(WITH_CUTLASS ON)
   endif()
-  if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.2)
-    include(external/flashattn)
-    list(APPEND third_party_deps extern_flashattn)
-    set(WITH_FLASHATTN ON)
+  if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.4)
+    foreach(arch ${NVCC_ARCH_BIN})
+      if(${arch} GREATER_EQUAL 80)
+        include(external/flashattn)
+        list(APPEND third_party_deps extern_flashattn)
+        set(WITH_FLASHATTN ON)
+        break()
+      endif()
+    endforeach()
   endif()
 endif()
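
With this change, the FlashAttention-2 dependency is only pulled in when the CUDA toolkit is at least 11.4 and at least one architecture in NVCC_ARCH_BIN is SM 8.0 or newer, matching the "Skip compile for sm less than 80" note in the commit message. Not part of this commit, and the function name is illustrative: a minimal C++ sketch of the same compute-capability floor checked at runtime through the CUDA runtime API.

#include <cuda_runtime.h>

#include <cstdio>

// Mirrors the build-time gate above: FlashAttention-2 kernels target
// compute capability 8.0 (Ampere) or newer.
bool DeviceSupportsFlashAttn2(int device_id) {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
    return false;
  }
  return prop.major >= 8;
}

int main() {
  std::printf("FA2-capable device 0: %s\n",
              DeviceSupportsFlashAttn2(0) ? "yes" : "no");
  return 0;
}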

paddle/phi/backends/dynload/flashattn.h

Lines changed: 2 additions & 0 deletions

@@ -45,7 +45,9 @@ extern void* flashattn_dso_handle;
 
 #define FLASHATTN_ROUTINE_EACH(__macro)       \
   __macro(flash_attn_fwd);                    \
+  __macro(flash_attn_varlen_fwd);             \
   __macro(flash_attn_bwd);                    \
+  __macro(flash_attn_varlen_bwd);             \
   __macro(flash_attn_fwd_with_bias_and_mask); \
   __macro(flash_attn_bwd_with_bias_and_mask); \
   __macro(flash_attn_error);
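
The two new entries expose flash_attn_varlen_fwd and flash_attn_varlen_bwd through Paddle's dynload layer, which expands FLASHATTN_ROUTINE_EACH once per symbol to generate a lazily resolved wrapper. Below is a simplified, hypothetical sketch of that X-macro pattern; the stub macro and the dlsym plumbing are illustrative and differ from Paddle's actual dynload implementation.

#include <dlfcn.h>

#include <cstdio>

// Illustrative only: assumed to be filled in elsewhere via dlopen().
static void* flashattn_dso_handle = nullptr;

#define FLASHATTN_ROUTINE_EACH(__macro) \
  __macro(flash_attn_fwd);              \
  __macro(flash_attn_varlen_fwd);       \
  __macro(flash_attn_bwd);              \
  __macro(flash_attn_varlen_bwd)

// Each expansion defines a small stub that resolves its symbol on first use.
#define DEFINE_FLASHATTN_STUB(__name)                                      \
  struct Stub__##__name {                                                  \
    void* get() {                                                          \
      static void* sym = dlsym(flashattn_dso_handle, #__name);             \
      if (sym == nullptr) {                                                \
        std::fprintf(stderr, "flashattn symbol not found: " #__name "\n"); \
      }                                                                    \
      return sym;                                                          \
    }                                                                      \
  }

FLASHATTN_ROUTINE_EACH(DEFINE_FLASHATTN_STUB);

Adding a routine to the list is then enough for a lookup stub to exist; the kernels below reach the new varlen entry points through wrappers generated from this list.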

paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu

Lines changed: 150 additions & 141 deletions

@@ -25,6 +25,7 @@
 
 #ifdef PADDLE_WITH_FLASHATTN
 #include "paddle/phi/backends/dynload/flashattn.h"
+#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
 #endif
 
 DECLARE_bool(cudnn_deterministic);
@@ -55,115 +56,89 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
   ctx.template Alloc<T>(dk);
   ctx.template Alloc<T>(dv);
 
-  cudaStream_t stream = ctx.stream();
-  bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
+  const cudaStream_t stream = ctx.stream();
 
   // q,k,v [total_*, num_heads, head_dim]
 
   auto dims = q.dims();
-  int64_t total_q = dims[0];
-  int64_t num_heads = dims[1];
-  int64_t head_size = dims[2];
-
-  int64_t total_k = k.dims()[0];
-  int64_t batch_size = cu_seqlens_q.numel() - 1;
-
-  int num_splits = 0;  // 0 for an internal heuristic, which is optimal
-  if (FLAGS_cudnn_deterministic) {
-    num_splits = 1;
-  }
-  bool zero_tensors = false;
-
-  const int64_t* seed_offset_data = seed_offset.data<int64_t>();
-  uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
-  uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
-
-  VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
-          << ", num_splits:" << num_splits;
-
-  int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
-  DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
-
-  uint64_t workspace_size;
-
-  // calculate workspace size before execution
-  bool succ = phi::dynload::flash_attn_bwd(
-      q.data(),
-      k.data(),
-      v.data(),
-      dq->data(),
-      dk->data(),
-      dv->data(),
-      nullptr,  // for calculation workspace size
-      dout.data(),
-      cu_seqlens_q.data(),
-      cu_seqlens_k.data(),
-      total_q,
-      total_k,
-      batch_size,
-      num_heads,
-      head_size,
-      max_seqlen_q,
-      max_seqlen_k,
-      dropout,
-      scale,
-      zero_tensors,
-      causal,
-      is_bf16,
-      num_splits,
-      const_cast<float*>(softmax_lse.data<float>()),
-      dsoftmax.data(),
-      nullptr,
-      &workspace_size,
-      stream,
-      seed,
-      offset);
-
-  if (!succ) {
-    PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
-  }
-
-  DenseTensor workspace;
-  if (workspace_size > 0) {
-    workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
-  }
-
-  succ = phi::dynload::flash_attn_bwd(
-      q.data(),
-      k.data(),
-      v.data(),
-      dq->data(),
-      dk->data(),
-      dv->data(),
-      out.data(),
-      dout.data(),
-      cu_seqlens_q.data(),
-      cu_seqlens_k.data(),
-      total_q,
-      total_k,
-      batch_size,
-      num_heads,
+  const int64_t total_q = dims[0];
+  const int batch_size = cu_seqlens_q.numel() - 1;
+  const int num_heads = dims[1];
+  const int head_size_og = dout.dims()[2];
+  const int head_size = dims[2];
+  const int total_k = k.dims()[0];
+  const int num_heads_k = k.dims()[1];
+
+  // TODO(umiswing): add deterministic in fa2.
+  // int num_splits = 0;  // 0 for an internal heuristic, which is optimal
+  // if (FLAGS_cudnn_deterministic) {
+  //   num_splits = 1;
+  // }
+
+  const bool zero_tensors = false;
+
+  // TODO(umiswing): add shape check
+  PADDLE_ENFORCE_EQ(
+      head_size_og,
       head_size,
-      max_seqlen_q,
-      max_seqlen_k,
-      dropout,
-      scale,
-      zero_tensors,
-      causal,
-      is_bf16,
-      num_splits,
-      const_cast<float*>(softmax_lse.data<float>()),
-      dsoftmax.data(),
-      workspace_size > 0 ? workspace.data() : nullptr,
-      &workspace_size,
-      stream,
-      seed,
-      offset);
+      phi::errors::InvalidArgument(
+          "flash_attn_bwd receive input with head_size_og == head_size"));
+
+  FlashAttnBwdParamsV2 params =
+      FlashAttnBwdParamsV2(ctx,
+                           batch_size,
+                           max_seqlen_q,
+                           max_seqlen_k,
+                           num_heads,
+                           num_heads_k,
+                           head_size,
+                           dropout,
+                           scale,
+                           causal,
+                           q.dtype(),
+                           seed_offset.data<int64_t>());
+
+  VLOG(4) << "FlashAttn bwd seed: " << params.seed
+          << ", offset: " << params.offset;
+
+  const bool succ =
+      phi::dynload::flash_attn_varlen_bwd(dout.data(),
+                                          q.data(),
+                                          k.data(),
+                                          v.data(),
+                                          out.data(),
+                                          params.softmax_d.data(),
+                                          softmax_lse.data(),
+                                          cu_seqlens_q.data<int32_t>(),
+                                          cu_seqlens_k.data<int32_t>(),
+                                          params.rng_state.data(),
+                                          dq->data(),
+                                          dk->data(),
+                                          dv->data(),
+                                          params.dq_accum.data(),
+                                          params.batch_size,
+                                          params.max_seqlen_q,
+                                          params.max_seqlen_k,
+                                          params.seqlen_q_rounded,
+                                          params.seqlen_k_rounded,
+                                          params.num_heads,
+                                          params.num_heads_k,
+                                          params.head_size,
+                                          params.head_size_rounded,
+                                          params.dropout,
+                                          params.scale,
+                                          params.causal,
+                                          params.is_bf16,
+                                          stream,
+                                          params.seed,
+                                          params.offset);
 
   if (!succ) {
     PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
   }
-
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "FlashAttention is unsupported, please set use_flash_attn to false."));
 #endif
 }
 
@@ -185,52 +160,86 @@ void FlashAttnGradKernel(const Context& ctx,
   // q,k,v [batch_size, seq_len, num_heads, head_dim]
 
   auto dims = q.dims();
-  int64_t batch_size = dims[0];
-  int64_t seq_len_q = dims[1];
-  int64_t num_heads = dims[2];
-  int64_t head_size = dims[3];
-
-  int64_t seq_len_k = k.dims()[1];
-
-  int64_t total_q = batch_size * seq_len_q;
-  int64_t total_k = batch_size * seq_len_k;
-
-  float scale = 1.0f / std::sqrt(head_size);
+  const int batch_size = dims[0];
+  const int seqlen_q = dims[1];
+  const int num_heads = dims[2];
+  const int head_size_og = dout.dims()[3];
+  const int head_size = dims[3];
+  const int seqlen_k = k.dims()[1];
+  const int num_heads_k = k.dims()[2];
+
+  // TODO(umiswing): add shape check
+  PADDLE_ENFORCE_EQ(
+      head_size_og,
+      head_size,
+      phi::errors::InvalidArgument(
+          "flash_attn_bwd receive input with head_size_og == head_size"));
 
   VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
           << "], v[" << v.dims() << "]";
 
-  DenseTensor q_t_s, k_t_s, v_t_s;
-  q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
-  k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
-  v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
-
-  DenseTensor cu_seqlens_q;
-  DenseTensor cu_seqlens_k;
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
-
-  FlashAttnUnpaddedGradKernel<T, Context>(ctx,
-                                          q_t_s,
-                                          k_t_s,
-                                          v_t_s,
-                                          cu_seqlens_q,
-                                          cu_seqlens_k,
-                                          out,
-                                          softmax_lse,
-                                          seed_offset,
-                                          dout,
-                                          seq_len_q,
-                                          seq_len_k,
-                                          scale,
-                                          dropout,
-                                          causal,
-                                          dq,
-                                          dk,
-                                          dv);
+  const float scale = 1.0f / std::sqrt(head_size);
+
+  FlashAttnBwdParamsV2 params =
+      FlashAttnBwdParamsV2(ctx,
+                           batch_size,
+                           seqlen_q,
+                           seqlen_k,
+                           num_heads,
+                           num_heads_k,
+                           head_size,
+                           dropout,
+                           scale,
+                           causal,
+                           q.dtype(),
+                           seed_offset.data<int64_t>());
+
+  ctx.template Alloc<T>(dq);
+  ctx.template Alloc<T>(dk);
+  ctx.template Alloc<T>(dv);
+
+  cudaStream_t stream = ctx.stream();
 
+  VLOG(4) << "FlashAttn bwd seed: " << params.seed
+          << ", offset: " << params.offset;
+
+  const bool succ = phi::dynload::flash_attn_bwd(dout.data(),
+                                                 q.data(),
+                                                 k.data(),
+                                                 v.data(),
+                                                 out.data(),
+                                                 params.softmax_d.data(),
+                                                 softmax_lse.data(),
+                                                 params.rng_state.data(),
+                                                 dq->data(),
+                                                 dk->data(),
+                                                 dv->data(),
+                                                 params.dq_accum.data(),
+                                                 params.batch_size,
+                                                 params.max_seqlen_q,
+                                                 params.max_seqlen_k,
+                                                 params.seqlen_q_rounded,
+                                                 params.seqlen_k_rounded,
+                                                 params.num_heads,
+                                                 params.num_heads_k,
+                                                 params.head_size,
+                                                 params.head_size_rounded,
+                                                 params.dropout,
+                                                 params.scale,
+                                                 params.causal,
+                                                 params.is_bf16,
+                                                 stream,
+                                                 params.seed,
+                                                 params.offset);
+
+  PADDLE_ENFORCE_EQ(
+      succ,
+      true,
+      phi::errors::External("Error in Flash-Attention-2, detail information is",
+                            phi::dynload::flash_attn_error()));
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "FlashAttention is unsupported, please set use_flash_attn to false."));
 #endif
 }
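
Both rewritten kernels funnel their scalar and workspace setup through FlashAttnBwdParamsV2 from the newly included flash_attn_utils.h. The sketch below is reconstructed only from how params is built and consumed in this diff; the member set mirrors that usage, while the rounding multiples, workspace shapes, and dtypes are assumptions rather than the actual header.

#include <cstdint>

// Uses DenseTensor, DataType and Empty<T> from the phi headers already
// included in this file; the struct name is suffixed to mark it as a sketch.
struct FlashAttnBwdParamsV2Sketch {
  int batch_size, max_seqlen_q, max_seqlen_k;
  int seqlen_q_rounded, seqlen_k_rounded;
  int num_heads, num_heads_k, head_size, head_size_rounded;
  float dropout, scale;
  bool causal, is_bf16;
  uint64_t seed, offset;
  DenseTensor softmax_d;  // per-row softmax gradient workspace
  DenseTensor dq_accum;   // float accumulator for dq
  DenseTensor rng_state;  // {seed, offset} handed to the FA2 C API

  template <typename Context>
  FlashAttnBwdParamsV2Sketch(const Context& ctx, int bs, int sq, int sk,
                             int nh, int nh_k, int hs, float p, float sc,
                             bool is_causal, DataType dtype,
                             const int64_t* seed_offset_data)
      : batch_size(bs), max_seqlen_q(sq), max_seqlen_k(sk), num_heads(nh),
        num_heads_k(nh_k), head_size(hs), dropout(p), scale(sc),
        causal(is_causal), is_bf16(dtype == DataType::BFLOAT16),
        seed(static_cast<uint64_t>(seed_offset_data[0])),
        offset(static_cast<uint64_t>(seed_offset_data[1])) {
    auto round_up = [](int x, int m) { return (x + m - 1) / m * m; };
    seqlen_q_rounded = round_up(max_seqlen_q, 128);  // assumed granularity
    seqlen_k_rounded = round_up(max_seqlen_k, 128);  // assumed granularity
    head_size_rounded = round_up(head_size, 32);     // assumed granularity
    softmax_d = Empty<float>(ctx, {batch_size, num_heads, seqlen_q_rounded});
    dq_accum = Empty<float>(
        ctx, {batch_size, num_heads, seqlen_q_rounded, head_size_rounded});
    rng_state = Empty<int64_t>(ctx, {2});
  }
};

Compared with the FA1 path it replaces, the rounded geometry, RNG state, and accumulators now travel in one object, and the two-pass workspace-size query before the real backward call is gone.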