132 changes: 50 additions & 82 deletions paddlenlp/transformers/ring_flash_attention.py
@@ -55,17 +55,11 @@ def wait(self):

def add_to_buffers(self, key, value):
if key.shape != self._k_buffer[self._next_buffer_idx].shape:
k_buffer_chunk = paddle.slice(
self._k_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[key.shape[1]]
)
v_buffer_chunk = paddle.slice(
self._v_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[value.shape[1]]
)
k_buffer_chunk += key
v_buffer_chunk += value
self._k_buffer[self._next_buffer_idx][:, : key.shape[1], :, :].add_(key)
self._v_buffer[self._next_buffer_idx][:, : key.shape[1], :, :].add_(value)
else:
self._k_buffer[self._next_buffer_idx] += key
self._v_buffer[self._next_buffer_idx] += value
self._k_buffer[self._next_buffer_idx].add_(key)
self._v_buffer[self._next_buffer_idx].add_(value)

def get_buffers(self):
return self._k_buffer[self._next_buffer_idx], self._v_buffer[self._next_buffer_idx]
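
A note on the rewrite above: basic indexing such as buf[:, :n, :, :] returns a view that shares storage with the buffer, so add_(...) writes through to it, while the old paddle.slice(...) result followed by += rebinds a local name and is not guaranteed to update the buffer. A minimal sketch of the distinction (illustrative only; paddle.slice + += behavior can vary across Paddle versions):

import paddle

buf = paddle.zeros([1, 4, 2, 2])
k = paddle.ones([1, 2, 2, 2])

# View + in-place add: writes through to the underlying buffer.
buf[:, : k.shape[1], :, :].add_(k)
assert float(buf.sum()) == 8.0

# Old pattern: "chunk += k" rebinds/copies, so the buffer update is
# not guaranteed -- hence the switch away from paddle.slice here.
chunk = paddle.slice(buf, axes=[1], starts=[0], ends=[k.shape[1]])
chunk += k  # may not propagate back into buf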
@@ -84,23 +78,21 @@ def send_recv(self):


def update_out_and_lse(old_out, old_lse, block_out, block_lse, second_chunk_only=False):
if old_out is None and old_lse is None:
return block_out.to("float32"), block_lse.to("float32")

if second_chunk_only:
second_chunk_out_ = paddle.slice(old_out, axes=[1], starts=[old_out.shape[1] // 2], ends=[old_out.shape[1]])
second_chunk_lse_ = paddle.slice(old_lse, axes=[1], starts=[old_lse.shape[1] // 2], ends=[old_lse.shape[1]])
second_chunk_out = old_out[:, old_out.shape[1] // 2 :, :, :]
second_chunk_lse = old_lse[:, old_lse.shape[1] // 2 :, :, :]
second_chunk_out, second_chunk_lse = update_out_and_lse(
second_chunk_out_, second_chunk_lse_, block_out, block_lse
second_chunk_out, second_chunk_lse, block_out, block_lse
)
paddle.assign(second_chunk_out, second_chunk_out_)
paddle.assign(second_chunk_lse, second_chunk_lse_)
old_out[:, old_out.shape[1] // 2 :, :, :] = second_chunk_out
old_lse[:, old_lse.shape[1] // 2 :, :, :] = second_chunk_lse
return old_out, old_lse
else:
block_out, block_lse = block_out.to("float32"), block_lse.to("float32")
with paddle.amp.auto_cast(enable=False, dtype="bfloat16"):
lse = old_lse - F.log_sigmoid(old_lse - block_lse)
return old_out - (old_out - block_out) * F.sigmoid(block_lse - old_lse), lse
block_out, block_lse = paddle.cast(block_out, "float32"), paddle.cast(block_lse, "float32")
with paddle.amp.auto_cast(enable=False):
return old_out - (old_out - block_out) * F.sigmoid(block_lse - old_lse), old_lse - F.log_sigmoid(
old_lse - block_lse
)
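
For reference, the sigmoid/log-sigmoid expressions above are the numerically stable online-softmax merge of two partial attention results: writing lse = log(exp(lse1) + exp(lse2)), the update is lse = lse1 - log_sigmoid(lse1 - lse2) and out = out1 - (out1 - out2) * sigmoid(lse2 - lse1), i.e. the softmax-weighted average of the two partial outputs. A NumPy check of the identity (illustrative only):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
out1, out2, lse1, lse2 = (rng.normal(size=8) for _ in range(4))

lse = lse1 + np.log1p(np.exp(lse2 - lse1))         # == lse1 - log_sigmoid(lse1 - lse2)
out = out1 - (out1 - out2) * sigmoid(lse2 - lse1)  # stable form used above

w1, w2 = np.exp(lse1 - lse), np.exp(lse2 - lse)    # softmax weights; w1 + w2 == 1
np.testing.assert_allclose(out, w1 * out1 + w2 * out2, rtol=1e-12, atol=1e-12)
np.testing.assert_allclose(lse, np.log(np.exp(lse1) + np.exp(lse2)), rtol=1e-12)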


def get_chunk_id(rank, cp_size):
@@ -130,14 +122,10 @@ def balanced_ring_flash_attention_fwd_func(
comm_buffer = RingCommunicator(group, local_key, local_value)
local_q_seq_len = local_query.shape[1]

out, lse, k_cache, v_cache = None, None, dict(), dict()

if attn_mask is not None:
attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
if is_causal:
local_query_second_chunk = paddle.slice(
local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
)
local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :]
for step in range(cp_size):
block_k, block_v = comm_buffer.get_buffers()

@@ -159,16 +147,19 @@
not training,
"",
)
block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)

if step == 0:
out, lse = block_out, block_lse
else:
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
else:
# block_k and block_v are from rank (group.rank - step) % cp_size
if step == 0:
block_out, _, block_lse, _ = _C_ops.flash_attn(
local_query, block_k, block_v, fixed_seed_offset, None, dropout, True, False, not training, ""
)
block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
out, lse = block_out, block_lse
elif step > rank:
block_out, _, block_lse, _ = _C_ops.flash_attn(
local_query_second_chunk,
@@ -182,16 +173,14 @@
not training,
"",
)
block_lse = paddle.slice(block_lse, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
block_lse = block_lse[:, :, 0 : (local_q_seq_len // 2)]
paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
out, lse = update_out_and_lse(out, lse, block_out, block_lse, True)
else:
block_k = paddle.slice(block_k, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
block_v = paddle.slice(block_v, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
block_out, _, block_lse, _ = _C_ops.flash_attn(
local_query,
block_k,
block_v,
block_k[:, : local_q_seq_len // 2, :, :],
block_v[:, : local_q_seq_len // 2, :, :],
fixed_seed_offset,
None,
dropout,
@@ -200,23 +189,19 @@
not training,
"",
)
block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
k_cache[step] = block_k
v_cache[step] = block_v

# TODO(zhangyuqin1998): under batch_isend_irecv async streams, wait() cannot be used and needs a fix; this hurts performance.
# if step != cp_size - 1:
# comm_buffer.wait()
paddle.device.synchronize()

out = out.to(local_query.dtype)
lse = paddle.transpose_(paddle.squeeze_(lse, axis=-1), [0, 2, 1])
return out, lse, k_cache, v_cache
return paddle.cast(out, local_query.dtype), paddle.transpose_(paddle.squeeze(lse, axis=-1), [0, 2, 1])
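
Taken together, the three causal branches above implement the balanced ("zigzag") ring schedule. An illustrative restatement, not part of the PR:

def causal_case(step, rank):
    # Mirrors the branch structure of the causal forward loop above.
    if step == 0:
        return "full local Q x full local KV, is_causal=True"
    elif step > rank:
        return "second half of local Q x full incoming KV, no mask"
    else:
        return "full local Q x first half of incoming KV, no mask"

The step > rank case is folded back in via update_out_and_lse(..., second_chunk_only=True), which is why that helper special-cases the trailing half of out and lse.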


def balanced_ring_flash_attention_bwd_func(
group,
k_cache,
v_cache,
out_grad,
local_query,
local_key,
@@ -240,17 +225,10 @@
grad_comm_buffer = RingCommunicator(group, key_grad_buffer, value_grad_buffer)

if is_causal:
local_query_second_chunk = paddle.slice(
local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
)
local_out_second_chunk = paddle.slice(
local_out, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
)
lse_second_chunk = paddle.slice(lse, axes=[2], starts=[local_q_seq_len // 2], ends=[local_q_seq_len])
out_grad_second_chunk = paddle.slice(out_grad, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len])
query_grad_buffer_second_chunk = paddle.slice(
query_grad_buffer, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
)
local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :]
local_out_second_chunk = local_out[:, local_q_seq_len // 2 :, :, :]
lse_second_chunk = lse[:, :, local_q_seq_len // 2 :]
out_grad_second_chunk = out_grad[:, local_q_seq_len // 2 :, :, :]

if attn_mask is not None:
attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
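
Shape bookkeeping for the views above (layouts inferred from this excerpt rather than independently confirmed): the forward accumulates lse as (batch, seqlen, heads, 1), and its final squeeze-and-transpose returns (batch, heads, seqlen), which is why lse_second_chunk slices the last axis while the query/out/grad views slice axis 1. A sketch:

import paddle

lse = paddle.zeros([2, 128, 8, 1])  # (b, s, h, 1) as accumulated in the forward
lse = paddle.transpose(paddle.squeeze(lse, axis=-1), [0, 2, 1])  # -> (b, h, s)
second_chunk = lse[:, :, lse.shape[2] // 2 :]  # trailing half of the sequence
assert second_chunk.shape == [2, 8, 64]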
@@ -274,13 +252,13 @@
dropout,
False,
)
query_grad_buffer += block_q_grad
query_grad_buffer.add_(block_q_grad)
else:
if step == 0:
block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd(
local_query, block_k, block_v, local_out, lse, fixed_seed_offset, None, out_grad, dropout, True
)
query_grad_buffer += block_q_grad
query_grad_buffer.add_(block_q_grad)
elif step > rank:
block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd(
local_query_second_chunk,
@@ -294,12 +272,12 @@
dropout,
False,
)
query_grad_buffer_second_chunk += block_q_grad
query_grad_buffer[:, local_q_seq_len // 2 :, :, :].add_(block_q_grad)
else:
block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd(
local_query,
k_cache[step],
v_cache[step],
block_k[:, : local_q_seq_len // 2, :, :],
block_v[:, : local_q_seq_len // 2, :, :],
local_out,
lse,
fixed_seed_offset,
@@ -308,9 +286,12 @@
dropout,
False,
)
query_grad_buffer += block_q_grad
query_grad_buffer.add_(block_q_grad)

# TODO(zhangyuqin1998): under batch_isend_irecv async streams, wait() cannot be used and needs a fix; this hurts performance.
# if step != cp_size - 1:
# kv_comm_buffer.wait()
# if step != 0:
# grad_comm_buffer.wait()
paddle.device.synchronize()

grad_comm_buffer.add_to_buffers(block_k_grad, block_v_grad)
@@ -319,8 +300,7 @@
grad_comm_buffer.wait()
key_grad_buffer, value_grad_buffer = grad_comm_buffer.get_buffers()

dtype = local_query.dtype
return query_grad_buffer.to(dtype), key_grad_buffer.to(dtype), value_grad_buffer.to(dtype)
return query_grad_buffer, key_grad_buffer, value_grad_buffer
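
The correctness of the gradient exchange rests on a simple invariant: the K/V blocks and their gradient buffers hop one rank per step, every rank adds its partial dK/dV for whichever block it currently holds, and after cp_size steps each block's accumulated gradient is complete and back at its owner. A toy check of that invariant (pure Python, no distributed setup; hop direction follows the "(rank - step) % cp_size" comment above):

cp_size = 4
holder = list(range(cp_size))                 # holder[r]: block currently at rank r
contrib = {b: set() for b in range(cp_size)}  # ranks that have added grads for block b

for step in range(cp_size):
    for rank in range(cp_size):
        contrib[holder[rank]].add(rank)       # rank adds its partial dK/dV
    holder = [holder[(r - 1) % cp_size] for r in range(cp_size)]  # one ring hop

assert holder == list(range(cp_size))         # every block is back at its owner
assert all(s == set(range(cp_size)) for s in contrib.values())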


class RingFlashAttention(PyLayer):
@@ -344,10 +324,10 @@ def forward(
if attn_mask is not None:
is_causal = False

out, lse, k_cache, v_cache = balanced_ring_flash_attention_fwd_func(
out, lse = balanced_ring_flash_attention_fwd_func(
group, query, key, value, fixed_seed_offset, attn_mask, dropout, is_causal, training
)
ctx.save_for_backward(query, key, value, out, lse, attn_mask, k_cache, v_cache)
ctx.save_for_backward(query, key, value, out, lse, attn_mask)
ctx.group = group
ctx.fixed_seed_offset = fixed_seed_offset
ctx.dropout = dropout
@@ -356,7 +336,7 @@ def forward(

@staticmethod
def backward(ctx, out_grad):
query, key, value, out, lse, attn_mask, k_cache, v_cache = ctx.saved_tensor()
query, key, value, out, lse, attn_mask = ctx.saved_tensor()
group = ctx.group
fixed_seed_offset = ctx.fixed_seed_offset
dropout = ctx.dropout
@@ -366,19 +346,7 @@ def backward(ctx, out_grad):
fixed_seed_offset = paddle.to_tensor([0, 0], place=paddle.CPUPlace(), dtype=paddle.int64)

query_grad, key_grad, value_grad = balanced_ring_flash_attention_bwd_func(
group,
k_cache,
v_cache,
out_grad,
query,
key,
value,
out,
lse,
fixed_seed_offset,
attn_mask,
dropout,
is_causal,
group, out_grad, query, key, value, out, lse, fixed_seed_offset, attn_mask, dropout, is_causal
)
if attn_mask is not None and not attn_mask.stop_gradient:
return query_grad, key_grad, value_grad, None
9 changes: 4 additions & 5 deletions tests/transformers/test_ring_flash_attention.py
@@ -83,17 +83,16 @@ def single_test(self, bsz, seq_len_per_device, head_num, head_dim, is_causal, us
)
ref_out = scaled_dot_product_attention(query, key, value, is_causal=is_causal, attn_mask=attn_mask)

local_out.mean().backward()
ref_out.mean().backward()
local_out.backward()
ref_out.backward()

ref_local_query_grad = self.split_belanced_data(query.grad)
ref_local_key_grad = self.split_belanced_data(key.grad)
ref_local_value_grad = self.split_belanced_data(value.grad)

ref_local_out = self.split_belanced_data(ref_out)

rtol = 1e-04
atol = 5e-03
rtol = 1e-02
atol = 1e-02
np.testing.assert_allclose(
local_out.to("float32").numpy(), ref_local_out.to("float32").numpy(), rtol=rtol, atol=atol
)
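
On the loosened tolerances: bfloat16 carries an 8-bit mantissa, so the relative spacing of representable values near 1.0 is 2**-8, about 3.9e-3, and a ring implementation additionally sums blocks in a different order than the single-device reference, so agreement much tighter than ~1e-2 is not generally achievable. A quick illustration (assumes a Paddle build with bfloat16 support):

import paddle

one = paddle.to_tensor(1.0, dtype="bfloat16")
eps = paddle.to_tensor(2.0 ** -9, dtype="bfloat16")  # below bf16 spacing at 1.0
print(float((one + eps).astype("float32")))  # expected: 1.0, the increment rounds away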