Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/external/flashattn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
set(FLASHATTN_SOURCE_SUBDIR csrc)
set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/flashattn)
set(FLASHATTN_TAG a96f8024714455fb86a326e20c3b7f700ec50772)
set(FLASHATTN_TAG 5fc132ac11e78d26471ca09e5ba0cd817c3424d8)

set(FLASHATTN_INCLUDE_DIR
"${FLASHATTN_INSTALL_DIR}/include"
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1323,15 +1323,15 @@ void FusedRopeGradInferMeta(const MetaTensor& sin,
"[batch_size, seq_len, num_heads, head_dim],"
"but got %u.",
input_dims.size()));
if (dout_q) {
if (dout_q && dq) {
dq->set_dims(dout_q.dims());
dq->set_dtype(dout_q.dtype());
}
if (dout_k) {
if (dout_k && dk) {
dk->set_dims(dout_k.dims());
dk->set_dtype(dout_k.dtype());
}
if (dout_v) {
if (dout_v && dv) {
dv->set_dims(dout_v.dims());
dv->set_dtype(dout_v.dtype());
}
Expand Down
109 changes: 73 additions & 36 deletions paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

PD_DECLARE_bool(cudnn_deterministic);

Expand Down Expand Up @@ -51,42 +52,53 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
DenseTensor* dk,
DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
// q,k,v [total_*, num_heads, head_dim]
auto dims = q.dims();

const int64_t batch_size = cu_seqlens_q.numel() - 1;
const int64_t num_heads = dims[1];
const int64_t head_size_og = dout.dims()[2];
const int64_t head_size = dims[2];
const int64_t total_k = k.dims()[0];
const int64_t num_heads_k = k.dims()[1];

bool is_mha = (num_heads == num_heads_k);

void* dq_ptr = nullptr;
void* dk_ptr = nullptr;
void* dv_ptr = nullptr;

ctx.template Alloc<T>(dq);
dq_ptr = dq->data();
DenseTensor dq_tmp;
if (dq) {
dq_ptr = ctx.template Alloc<T>(dq);
} else {
dq_tmp.Resize(dims);
dq_ptr = ctx.template Alloc<T>(&dq_tmp);
}

std::initializer_list<int64_t> dk_dv_shape = {
total_k, num_heads_k, num_heads / num_heads_k, head_size};

DenseTensor dk_tmp;
if (dk) {
if (dk && is_mha) {
ctx.template Alloc<T>(dk);
dk_ptr = dk->data();
} else {
dk_tmp = EmptyLike<T, Context>(ctx, k);
dk_ptr = dk_tmp.data();
dk_tmp.Resize(dk_dv_shape);
dk_ptr = ctx.template Alloc<T>(&dk_tmp);
}

DenseTensor dv_tmp;
if (dv) {
if (dv && is_mha) {
ctx.template Alloc<T>(dv);
dv_ptr = dv->data();
} else {
dv_tmp = EmptyLike<T, Context>(ctx, v);
dv_ptr = dv_tmp.data();
dv_tmp.Resize(dk_dv_shape);
dv_ptr = ctx.template Alloc<T>(&dv_tmp);
}

const cudaStream_t stream = ctx.stream();

// q,k,v [total_*, num_heads, head_dim]
auto dims = q.dims();

const int64_t batch_size = cu_seqlens_q.numel() - 1;
const int64_t num_heads = dims[1];
const int64_t head_size_og = dout.dims()[2];
const int64_t head_size = dims[2];
const int64_t num_heads_k = k.dims()[1];

int num_splits = get_num_split();

// TODO(umiswing): add shape check
Expand Down Expand Up @@ -150,6 +162,14 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr,
params.attn_mask_tensor ? params.mask_dims.data() : nullptr);
CheckFlashAttnStatus(succ);
if (!is_mha) {
if (dk) {
phi::SumKernel<T, Context>(ctx, dk_tmp, {2}, dk->type(), false, dk);
}
if (dv) {
phi::SumKernel<T, Context>(ctx, dv_tmp, {2}, dv->type(), false, dv);
}
}
#else
RaiseNotSupportedError();
#endif
Expand All @@ -171,44 +191,53 @@ void FlashAttnGradKernel(const Context& ctx,
DenseTensor* dk,
DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
// q, k, v [batch_size, seq_len, num_heads, head_dim]
const auto& dims = q.dims();

const int64_t batch_size = dims[0];
const int64_t seqlen_q = dims[1];
const int64_t num_heads = dims[2];
const int64_t head_size_og = dout.dims()[3];
const int64_t head_size = dims[3];
const int64_t seqlen_k = k.dims()[1];
const int64_t num_heads_k = k.dims()[2];

bool is_mha = (num_heads == num_heads_k);

void* dq_ptr = nullptr;
void* dk_ptr = nullptr;
void* dv_ptr = nullptr;

ctx.template Alloc<T>(dq);
dq_ptr = dq->data();
DenseTensor dq_tmp;
if (dq) {
dq_ptr = ctx.template Alloc<T>(dq);
} else {
dq_tmp.Resize(dims);
dq_ptr = ctx.template Alloc<T>(&dq_tmp);
}

DenseTensor dk_tmp;
if (dk) {
std::initializer_list<int64_t> dk_dv_shape = {
batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size};
if (dk && is_mha) {
ctx.template Alloc<T>(dk);
dk_ptr = dk->data();
} else {
dk_tmp = EmptyLike<T, Context>(ctx, k);
dk_ptr = dk_tmp.data();
dk_tmp.Resize(dk_dv_shape);
dk_ptr = ctx.template Alloc<T>(&dk_tmp);
}

DenseTensor dv_tmp;
if (dv) {
if (dv && is_mha) {
ctx.template Alloc<T>(dv);
dv_ptr = dv->data();
} else {
dv_tmp = EmptyLike<T, Context>(ctx, v);
dv_ptr = dv_tmp.data();
dv_tmp.Resize(dk_dv_shape);
dv_ptr = ctx.template Alloc<T>(&dv_tmp);
}

const cudaStream_t stream = ctx.stream();

// q, k, v [batch_size, seq_len, num_heads, head_dim]
const auto& dims = q.dims();

const int64_t batch_size = dims[0];
const int64_t seqlen_q = dims[1];
const int64_t num_heads = dims[2];
const int64_t head_size_og = dout.dims()[3];
const int64_t head_size = dims[3];
const int64_t seqlen_k = k.dims()[1];
const int64_t num_heads_k = k.dims()[2];

// TODO(umiswing): add shape check
PADDLE_ENFORCE_EQ(
head_size_og,
Expand Down Expand Up @@ -281,6 +310,14 @@ void FlashAttnGradKernel(const Context& ctx,
params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr,
params.attn_mask_tensor ? params.mask_dims.data() : nullptr);
CheckFlashAttnStatus(succ);
if (!is_mha) {
if (dk) {
phi::SumKernel<T, Context>(ctx, dk_tmp, {3}, dk->type(), false, dk);
}
if (dv) {
phi::SumKernel<T, Context>(ctx, dv_tmp, {3}, dv->type(), false, dv);
}
}
#else
RaiseNotSupportedError();
#endif
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/flash_attn_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ struct FlashAttnParamsBase {
max_seqlen_q(_max_seqlen_q),
max_seqlen_k(_max_seqlen_k),
num_heads(_num_heads),
num_heads_k(_num_heads),
num_heads_k(_num_heads_k),
head_size(_head_size),
softmax_scale(_scale),
causal(_causal),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,8 @@ def send_forward_backward_recv_forward_backward(
if _timers is not None:
_timers("send_forward_backward_recv_forward_backward").start()

self._send_meta(output_tensor)
if output_tensor is not None:
self._send_meta(output_tensor)
if recv_prev:
self._recv_meta()

Expand Down
Loading