
Commit c8e6545

umiswing authored and kuizhiqing committed
part-3 cherry-pick from: [cherry-pick] Integration flash attention 2 (PaddlePaddle#56015)

* [FlashAttn] add flash randomness control (PaddlePaddle#52902)
* add flash randomness control
* fix VLOG undefined
* [WIP] Integration flash attention 2 (PaddlePaddle#55758)
* Work for fa-2 padded fwd. Code to be cleaned.
* Work for fa-2 unpadded fwd.
* Work for padded bwd; dk gets a small diff with np.random.seed(0).
* Anyway I pass Paddle's utest, except return softmax without dropout.
* Clean code.
* Modify interface.
* Clean code and add some checks.
* Easy compile for dev.
* Fix CI.
* Fix CI build.
* Add std c++17 option again.
* Limit max jobs when compiling fa-2.
* Remove const_cast.
* Add fwd params, to be cleaned.
* Clean code.
* Add bwd params.
* Clean code.
* Add enforce.
* Use v2.0.4.
* Pass RNG state to fa-2 C API.
* Fix review.
* Add assert.
* Skip compile for sm less than 80.

Co-authored-by: Chitsing KUI <[email protected]>
1 parent: cc7279c

File tree

5 files changed (+182, -39 lines)

cmake/external/flashattn.cmake

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
 set(FLASHATTN_SOURCE_SUBDIR csrc)
 set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
 set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/flashattn)
-set(FLASHATTN_TAG 18106c1ba0ccee81b97ca947397c08a141815a47)
+set(FLASHATTN_TAG b5bdb79d5e1f2f88b1ef62e86899a14f82fa079a)

 set(FLASHATTN_INCLUDE_DIR
     "${FLASHATTN_INSTALL_DIR}/include"

paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu

Lines changed: 5 additions & 1 deletion

@@ -21,9 +21,13 @@
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/arange_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
-#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
 #include "paddle/phi/kernels/reshape_kernel.h"

+#ifdef PADDLE_WITH_FLASHATTN
+#include "paddle/phi/backends/dynload/flashattn.h"
+#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
+#endif
+
 PD_DECLARE_bool(cudnn_deterministic);

 namespace phi {

paddle/phi/kernels/gpu/flash_attn_kernel.cu

Lines changed: 4 additions & 2 deletions

@@ -21,10 +21,12 @@
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/arange_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
-#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
 #include "paddle/phi/kernels/reshape_kernel.h"

-PD_DECLARE_bool(cudnn_deterministic);
+#ifdef PADDLE_WITH_FLASHATTN
+#include "paddle/phi/backends/dynload/flashattn.h"
+#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
+#endif

 namespace phi {

Lines changed: 101 additions & 0 deletions

@@ -0,0 +1,101 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+import logging
+
+from ...utils.log_utils import get_logger
+
+_logger = get_logger(logging.INFO)
+from ..random import determinate_rng, is_enable_auto_rand_ctrl
+from .common import (
+    DistributedOperatorImplContainer,
+    register_distributed_operator_impl,
+    register_distributed_operator_impl_container,
+)
+from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0
+
+
+class DistributedFlashAttn(DistributedOperatorImplContainer):
+    def __init__(self, op_type):
+        super().__init__(op_type)
+
+
+register_distributed_operator_impl_container(DistributedFlashAttn("flash_attn"))
+
+
+# Dist FlashAttn with Random Control
+class DistributedFlashAttnImpl0(DistributedElementwiseImpl0):
+    def __init__(self, name):
+        super().__init__(name)
+        self._forward_implemented = True
+        self._backward_implemented = True
+
+    def is_input_compatible(self, dist_op):
+        return True
+
+    def is_output_compatible(self, dist_op):
+        return True
+
+    def is_auto_compatible(self, dist_op):
+        return True
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        dist_op_context = ctx.dist_op_context
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
+        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
+
+        if (
+            is_enable_auto_rand_ctrl()
+            and not op_dist_attr.is_recompute
+            and rank_id in op_dist_attr.process_mesh.process_ids
+        ):
+            assert (
+                op_dist_attr is not None
+            ), f"forward op [{str(src_op)}] don't have dist attribute !"
+
+            if (
+                len(kwargs.get('fixed_seed_offset', [])) > 0
+                or len(src_op.input("fixed_seed_offset")) > 0
+            ):
+                # TODO(kuizhiqing) recompute should go here
+                pass
+            else:
+                # determinate rng
+                q_var = main_block._var_recursive(kwargs['q'][0])
+                k_var = main_block._var_recursive(kwargs['k'][0])
+                q_dims_mapping = op_dist_attr.get_input_dims_mapping(q_var.name)
+                k_dims_mapping = op_dist_attr.get_input_dims_mapping(k_var.name)
+                process_mesh = op_dist_attr.process_mesh
+                dims_mapping = q_dims_mapping[:3] + [q_dims_mapping[2]]
+
+                rng_name = determinate_rng(rank_id, dims_mapping, process_mesh)
+                assert rng_name is not None and rng_name != ""
+
+                src_op._set_attr('rng_name', rng_name)
+
+        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
+
+    @staticmethod
+    def backward(ctx, *args, **kwargs):
+        # dropout backward is deterministic by mask, and not need for random state control
+        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
+
+
+register_distributed_operator_impl(
+    "flash_attn", DistributedFlashAttnImpl0("random_control")
+)
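
To make the dims_mapping arithmetic in forward() above concrete, here is a small self-contained sketch (plain Python lists, not the Paddle API; the example mapping values are assumptions for illustration). q is taken to be laid out as [batch, seqlen, num_heads, head_dim], and each dims_mapping entry names the process-mesh axis that dimension is sharded along, with -1 meaning replicated.

# Illustrative sketch only: mirrors `dims_mapping = q_dims_mapping[:3] + [q_dims_mapping[2]]`.
# Assumed layout for q: [batch, seqlen, num_heads, head_dim]; -1 means replicated.
q_dims_mapping = [0, -1, 1, -1]  # batch sharded on mesh axis 0, heads on mesh axis 1

# Keep the first three entries and reuse the num_heads entry in place of head_dim,
# so the mapping handed to determinate_rng ignores q's head_dim sharding.
dims_mapping = q_dims_mapping[:3] + [q_dims_mapping[2]]
print(dims_mapping)  # [0, -1, 1, 1]

The intent, judging from the surrounding code, is that ranks covering the same logical shard derive the same rng_name and therefore the same dropout randomness.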

python/paddle/nn/functional/flash_attention.py

Lines changed: 71 additions & 35 deletions

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+
 import paddle
 import paddle.nn.functional as F
 from paddle import _C_ops, in_dynamic_mode
@@ -22,6 +24,10 @@
 g_enable_flash = None
 g_enable_mem_efficient = None

+g_use_flash_attn_v1 = (
+    os.getenv('FLAGS_flash_attn_version', 'v2').strip().lower() == 'v1'
+)
+

 @signature_safe_contextmanager
 def sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=True):
@@ -222,21 +228,32 @@ def flash_attention(

     if sdp_func_name == "flash_attn":
         if in_dynamic_mode():
-            (
-                result_attention,
-                result_softmax,
-            ) = _C_ops.flash_attn(
-                query,
-                key,
-                value,
-                fixed_seed_offset,
-                None,
-                dropout,
-                causal,
-                return_softmax,
-                not training,
-                rng_name,
-            )
+            if g_use_flash_attn_v1:
+                (result_attention, result_softmax, _, _) = _C_ops.flash_attn_v1(
+                    query,
+                    key,
+                    value,
+                    dropout,
+                    causal,
+                    return_softmax,
+                    not training,
+                )
+            else:
+                (
+                    result_attention,
+                    result_softmax,
+                ) = _C_ops.flash_attn(
+                    query,
+                    key,
+                    value,
+                    fixed_seed_offset,
+                    None,
+                    dropout,
+                    causal,
+                    return_softmax,
+                    not training,
+                    rng_name,
+                )
             return result_attention, result_softmax if return_softmax else None

     helper = LayerHelper('flash_attn', **locals())
@@ -377,26 +394,45 @@ def flash_attn_unpadded(

     """
     if in_dynamic_mode():
-        (
-            result_attention,
-            result_softmax,
-        ) = _C_ops.flash_attn_unpadded(
-            query,
-            key,
-            value,
-            cu_seqlens_q,
-            cu_seqlens_k,
-            fixed_seed_offset,
-            None,
-            max_seqlen_q,
-            max_seqlen_k,
-            scale,
-            dropout,
-            causal,
-            return_softmax,
-            not training,
-            rng_name,
-        )
+        if g_use_flash_attn_v1:
+            (
+                result_attention,
+                result_softmax,
+            ) = _C_ops.flash_attn_unpadded(
+                query,
+                key,
+                value,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                scale,
+                dropout,
+                causal,
+                return_softmax,
+                not training,
+            )
+        else:
+            (
+                result_attention,
+                result_softmax,
+            ) = _C_ops.flash_attn_unpadded(
+                query,
+                key,
+                value,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                fixed_seed_offset,
+                None,
+                max_seqlen_q,
+                max_seqlen_k,
+                scale,
+                dropout,
+                causal,
+                return_softmax,
+                not training,
+                rng_name,
+            )
         return result_attention, result_softmax if return_softmax else None

     helper = LayerHelper('flash_attn_unpadded', **locals())
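
For context, a minimal usage sketch of the Python API touched in this file. It assumes a GPU build of Paddle with flash-attn compiled in; the tensor shapes are illustrative. Because g_use_flash_attn_v1 is evaluated once at import time, FLAGS_flash_attn_version has to be set in the environment before this module is first imported.

import os

# Set before importing paddle.nn.functional.flash_attention; 'v1' routes the
# dynamic-mode path to _C_ops.flash_attn_v1, anything else keeps the default v2 path.
os.environ.setdefault('FLAGS_flash_attn_version', 'v2')

import paddle
from paddle.nn.functional.flash_attention import flash_attention

# Illustrative shapes: [batch, seqlen, num_heads, head_dim], half precision on GPU.
q = paddle.randn([2, 128, 8, 64], dtype='float16')
k = paddle.randn([2, 128, 8, 64], dtype='float16')
v = paddle.randn([2, 128, 8, 64], dtype='float16')

out, softmax = flash_attention(q, k, v, dropout=0.1, causal=True, return_softmax=False)
print(out.shape)  # [2, 128, 8, 64]
print(softmax)    # None, since return_softmax=False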
