Commit cc9a768

umiswing and kuizhiqing authored
[cherry-pick] Integration flash attention 2 (#56015)
* [FlashAttn] add flash randomness control (#52902)
  * add flash randomness control
  * fix VLOG undefined
* [WIP] Integration flash attention 2 (#55758)
  * Work for FA-2 padded fwd. Code to be cleaned.
  * Work for FA-2 unpadded fwd.
  * Work for padded bwd; dk gets a small diff with np.random.seed(0).
  * Passes Paddle's unit tests, except returning softmax without dropout.
  * Clean code.
  * Modify interface.
  * Clean code and add some checks.
  * Easy compile for dev.
  * Fix CI.
  * Fix CI build.
  * Add -std=c++17 option again.
  * Limit max jobs when compiling FA-2.
  * Remove const_cast.
  * Add fwd params, to be cleaned.
  * Clean code.
  * Add bwd params.
  * Clean code.
  * Add enforce.
  * Use v2.0.4.
  * Pass RNG state to the FA-2 C API.
  * Fix review comments.
  * Add assert.
  * Skip compile for SM less than 80.

---------

Co-authored-by: Chitsing KUI <[email protected]>
1 parent 8d3a988 commit cc9a768

File tree

13 files changed: +682 -344 lines

cmake/external/flashattn.cmake

Lines changed: 5 additions & 3 deletions
@@ -17,10 +17,10 @@ include(ExternalProject)
 add_definitions(-DPADDLE_WITH_FLASHATTN)
 
 set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
-set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
+set(FLASHATTN_SOURCE_SUBDIR csrc)
 set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
 set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
-set(FLASHATTN_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)
+set(FLASHATTN_TAG b5bdb79d5e1f2f88b1ef62e86899a14f82fa079a)
 
 set(FLASHATTN_INCLUDE_DIR
     "${FLASHATTN_INSTALL_DIR}/include"
@@ -62,7 +62,7 @@ else()
   set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
   set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
   set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
-  set(FLASHATTN_CXX_FLAGS ${CMAKE_CXX_FLAGS})
+  set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
  set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
  set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
 endif()
@@ -93,6 +93,8 @@ ExternalProject_Add(
              -DBUILD_SHARED=ON
              -DCMAKE_POSITION_INDEPENDENT_CODE=ON
              -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
+             -DCMAKE_JOB_POOL_COMPILE:STRING=compile
+             -DCMAKE_JOB_POOLS:STRING=compile=4
              ${EXTERNAL_OPTIONAL_ARGS}
   CMAKE_CACHE_ARGS
              -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}

cmake/third_party.cmake

Lines changed: 9 additions & 4 deletions
@@ -512,10 +512,15 @@ if(WITH_GPU
     list(APPEND third_party_deps extern_cutlass)
     set(WITH_CUTLASS ON)
   endif()
-  if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.2)
-    include(external/flashattn)
-    list(APPEND third_party_deps extern_flashattn)
-    set(WITH_FLASHATTN ON)
+  if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.4)
+    foreach(arch ${NVCC_ARCH_BIN})
+      if(${arch} GREATER_EQUAL 80)
+        include(external/flashattn)
+        list(APPEND third_party_deps extern_flashattn)
+        set(WITH_FLASHATTN ON)
+        break()
+      endif()
+    endforeach()
   endif()
 endif()
 
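This guard now requires CUDA 11.4+ and pulls in the flash-attention 2 dependency only when at least one target architecture in NVCC_ARCH_BIN is SM 80 (Ampere) or newer. A minimal sketch of the equivalent runtime capability check, written against the plain CUDA runtime API rather than Paddle's internals; the threshold of 80 simply mirrors the CMake guard above:

#include <cstdio>
#include <cuda_runtime.h>

// Returns true if the given device reports SM 8.0 or newer, the same
// cutoff as the `if(${arch} GREATER_EQUAL 80)` check in the build guard.
bool DeviceSupportsFlashAttn2(int device_id = 0) {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
    return false;  // no usable CUDA device
  }
  const int sm = prop.major * 10 + prop.minor;  // e.g. 80 on A100, 90 on H100
  return sm >= 80;
}

int main() {
  std::printf("flash-attn 2 eligible: %s\n",
              DeviceSupportsFlashAttn2() ? "yes" : "no");
  return 0;
}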

paddle/phi/api/yaml/backward.yaml

Lines changed: 2 additions & 2 deletions
@@ -617,7 +617,7 @@
   inplace : (out_grad -> x_grad)
 
 - backward_op : flash_attn_grad
-  forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
@@ -628,7 +628,7 @@
   data_type: q
 
 - backward_op : flash_attn_unpadded_grad
-  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :

paddle/phi/api/yaml/ops.yaml

Lines changed: 4 additions & 2 deletions
@@ -678,8 +678,9 @@
   backward : fill_diagonal_tensor_grad
 
 - op : flash_attn
-  args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
+  args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  optional : fixed_seed_offset
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]
@@ -690,8 +691,9 @@
   backward : flash_attn_grad
 
 - op : flash_attn_unpadded
-  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
+  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  optional : fixed_seed_offset
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]

paddle/phi/backends/dynload/flashattn.h

Lines changed: 7 additions & 3 deletions
@@ -43,9 +43,13 @@ extern void* flashattn_dso_handle;
 #define DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP(__name) \
   DYNAMIC_LOAD_FLASHATTN_WRAP(__name)
 
-#define FLASHATTN_ROUTINE_EACH(__macro) \
-  __macro(flash_attn_fwd);              \
-  __macro(flash_attn_bwd);              \
+#define FLASHATTN_ROUTINE_EACH(__macro)       \
+  __macro(flash_attn_fwd);                    \
+  __macro(flash_attn_varlen_fwd);             \
+  __macro(flash_attn_bwd);                    \
+  __macro(flash_attn_varlen_bwd);             \
+  __macro(flash_attn_fwd_with_bias_and_mask); \
+  __macro(flash_attn_bwd_with_bias_and_mask); \
   __macro(flash_attn_error);
 
 FLASHATTN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP);
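The routine list now also declares wrappers for the varlen and bias/mask entry points exported by the updated flashattn library. A hedged sketch of the general dlopen/dlsym lazy-loading pattern such wrapper macros expand to; the names, error handling, and the flash_attn_error signature here are simplified assumptions, not the actual macro expansion:

#include <dlfcn.h>
#include <stdexcept>
#include <string>

// Resolve a symbol from libflashattn.so on first use and cache the handle.
template <typename FuncPtr>
FuncPtr LoadFlashAttnSymbol(const char* name) {
  static void* handle = [] {
    void* h = dlopen("libflashattn.so", RTLD_LAZY | RTLD_LOCAL);
    if (h == nullptr) throw std::runtime_error("cannot load libflashattn.so");
    return h;
  }();
  void* sym = dlsym(handle, name);
  if (sym == nullptr) {
    throw std::runtime_error(std::string("missing symbol: ") + name);
  }
  return reinterpret_cast<FuncPtr>(sym);
}

// Example wrapper: look up flash_attn_error once, then call through it
// (assumed here to return an error string).
using flash_attn_error_fn = const char* (*)();
const char* CallFlashAttnError() {
  static auto fn = LoadFlashAttnSymbol<flash_attn_error_fn>("flash_attn_error");
  return fn();
}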

paddle/phi/kernels/flash_attn_kernel.h

Lines changed: 22 additions & 17 deletions
@@ -20,33 +20,38 @@
 namespace phi {
 
 template <typename T, typename Context>
-void FlashAttnUnpaddedKernel(const Context& ctx,
-                             const DenseTensor& q,
-                             const DenseTensor& k,
-                             const DenseTensor& v,
-                             const DenseTensor& cu_seqlens_q,
-                             const DenseTensor& cu_seqlens_k,
-                             int64_t max_seqlen_q,
-                             int64_t max_seqlen_k,
-                             float scale,
-                             float dropout,
-                             bool causal,
-                             bool return_softmax,
-                             bool is_test,
-                             DenseTensor* out,
-                             DenseTensor* softmax,
-                             DenseTensor* softmax_lse,
-                             DenseTensor* seed_offset);
+void FlashAttnUnpaddedKernel(
+    const Context& ctx,
+    const DenseTensor& q,
+    const DenseTensor& k,
+    const DenseTensor& v,
+    const DenseTensor& cu_seqlens_q,
+    const DenseTensor& cu_seqlens_k,
+    const paddle::optional<DenseTensor>& fixed_seed_offset,
+    int64_t max_seqlen_q,
+    int64_t max_seqlen_k,
+    float scale,
+    float dropout,
+    bool causal,
+    bool return_softmax,
+    bool is_test,
+    const std::string& rng_name,
+    DenseTensor* out,
+    DenseTensor* softmax,
+    DenseTensor* softmax_lse,
+    DenseTensor* seed_offset);
 
 template <typename T, typename Context>
 void FlashAttnKernel(const Context& ctx,
                      const DenseTensor& q,
                      const DenseTensor& k,
                      const DenseTensor& v,
+                     const paddle::optional<DenseTensor>& fixed_seed_offset,
                      float dropout,
                      bool causal,
                      bool return_softmax,
                      bool is_test,
+                     const std::string& rng_name,
                      DenseTensor* out,
                      DenseTensor* softmax,
                      DenseTensor* softmax_lse,
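FlashAttnUnpaddedKernel works on packed variable-length batches: q/k/v hold all tokens back to back, and cu_seqlens_q/cu_seqlens_k mark where each sequence starts. A small illustrative helper (not Paddle code; the int32 dtype and prefix-sum layout follow the usual flash-attention varlen convention and are assumptions here) that builds such an offsets array from per-sequence lengths:

#include <cstdint>
#include <iostream>
#include <vector>

// cu_seqlens has batch_size + 1 entries: sequence i occupies rows
// [cu_seqlens[i], cu_seqlens[i + 1]) of the packed tensor, and
// max_seqlen_* should be at least the largest individual length.
std::vector<int32_t> BuildCuSeqlens(const std::vector<int32_t>& seqlens) {
  std::vector<int32_t> cu(seqlens.size() + 1, 0);
  for (size_t i = 0; i < seqlens.size(); ++i) {
    cu[i + 1] = cu[i] + seqlens[i];
  }
  return cu;
}

int main() {
  // Lengths {5, 2, 7} -> cu_seqlens {0, 5, 7, 14}.
  for (int32_t v : BuildCuSeqlens({5, 2, 7})) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}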
