230 changes: 229 additions & 1 deletion paddle/phi/infermeta/spmd_rules/flash_attention.cc
@@ -223,7 +223,7 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q,
// [batch_size, num_heads, seq_len_q, seq_len_kv]
std::string softmax_axes = {
batch_axis, num_heads_axis, seq_len_q_axis, seq_len_kv_axis};
// [batch_size, num_heads, seq_len_q, seq_len_kv]
// [batch_size, num_heads, seq_len_q]
std::string softmax_lse_axes = {batch_axis, num_heads_axis, seq_len_q_axis};

auto q_dist_attr_dst = UnShardTensorDims(q_dist_attr, {1, 3});
@@ -277,6 +277,234 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q,
{out, softmax, softmax_lse, seed_offset}};
}

// Thin wrapper around FlashAttInferSpmd that drops the trailing rng_name
// argument; this is the signature registered for the static graph in rules.h.
SpmdInfo FlashAttInferSpmdStatic(const DistMetaTensor& q,
const DistMetaTensor& k,
const DistMetaTensor& v,
const DistMetaTensor& fixed_seed_offset,
const DistMetaTensor& attn_mask,
float dropout,
bool causal,
bool return_softmax,
bool is_test) {
return FlashAttInferSpmd(q,
k,
v,
fixed_seed_offset,
attn_mask,
dropout,
causal,
return_softmax,
is_test);
}

// Reverse rule: given the shardings of the outputs (out and softmax_lse),
// infer consistent shardings for the inputs q, k and v.
SpmdInfo FlashAttInferSpmdReverse(const DistMetaTensor& q,
const DistMetaTensor& k,
const DistMetaTensor& v,
const DistMetaTensor& fixed_seed_offset,
const DistMetaTensor& attn_mask,
const DistMetaTensor& out,
const DistMetaTensor& softmax,
const DistMetaTensor& softmax_lse,
const DistMetaTensor& seed_offset,
float dropout,
bool causal,
bool return_softmax,
bool is_test) {
// q
// [batch_size, seq_len_q, num_heads, head_dim]
auto q_shape = common::vectorize(q.dims());
auto q_dist_attr = q.dist_attr();

// k
// [batch_size, seq_len_kv, num_heads, head_dim]
auto k_shape = common::vectorize(k.dims());
auto k_dist_attr = k.dist_attr();

// v
// [batch_size, seq_len_kv, num_heads, head_dim]
auto v_shape = common::vectorize(v.dims());
auto v_dist_attr = v.dist_attr();

// fixed_seed_offset
// TODO(liuzhenhai): process fixed_seed_offset, seed_offset, and attn_mask
auto fixed_seed_offset_dist_attr = fixed_seed_offset.dist_attr();
auto fixed_seed_offset_shape = common::vectorize(fixed_seed_offset.dims());
// attn_mask
auto attn_mask_shape = common::vectorize(attn_mask.dims());
int mask_ndim = attn_mask_shape.size();
auto attn_mask_dist_attr = attn_mask.dist_attr();
int mask_dims_mapping_size = attn_mask_dist_attr.dims_mapping().size();
if (!IsEmpty(attn_mask_shape)) {
PADDLE_ENFORCE_EQ(
mask_ndim,
mask_dims_mapping_size,
phi::errors::InvalidArgument("The Tensor mask's rank [%d] and its "
"dims_mapping size [%d] are not matched.",
mask_ndim,
mask_dims_mapping_size));
}

// out
// [batch_size, seq_len_q, num_heads, head_dim_v]
auto out_shape = common::vectorize(out.dims());
int out_ndim = out_shape.size();
auto out_dist_attr = out.dist_attr();
int out_dims_mapping_size = out_dist_attr.dims_mapping().size();
PADDLE_ENFORCE_EQ(out_ndim,
4,
phi::errors::InvalidArgument(
"The Tensor out's shape must be [batch_size, "
"seq_len_q, num_heads, head_dim_v]"));

auto batch_size = out_shape[0];
auto seq_len_q = out_shape[1];
auto num_heads = out_shape[2];

PADDLE_ENFORCE_EQ(
out_ndim,
out_dims_mapping_size,
phi::errors::InvalidArgument("The Tensor out's rank [%d] and its "
"dims_mapping size [%d] are not matched.",
out_ndim,
out_dims_mapping_size));

// softmax_lse
// [batch_size, num_heads, seq_len_q]
auto softmax_lse_shape = common::vectorize(softmax_lse.dims());
int softmax_lse_ndim = softmax_lse_shape.size();
auto softmax_lse_dist_attr = softmax_lse.dist_attr();
int softmax_lse_dims_mapping_size =
softmax_lse_dist_attr.dims_mapping().size();
PADDLE_ENFORCE_EQ(softmax_lse_ndim,
3,
phi::errors::InvalidArgument(
"The Tensor softmax_lse's shape must be [batch_size, "
"num_heads, seq_len_q]"));

PADDLE_ENFORCE_EQ(
softmax_lse_ndim,
softmax_lse_dims_mapping_size,
phi::errors::InvalidArgument("The Tensor softmax_lse's rank [%d] and its "
"dims_mapping size [%d] are not matched.",
softmax_lse_ndim,
softmax_lse_dims_mapping_size));

auto batch_size_2 = softmax_lse_shape[0];
auto num_heads_2 = softmax_lse_shape[1];
auto seq_len_q_2 = softmax_lse_shape[2];

PADDLE_ENFORCE_EQ(
batch_size,
batch_size_2,
phi::errors::InvalidArgument(
"batch size of Tensor out and softmax_lse is not matched: [%d] vs [%d]",
batch_size,
batch_size_2));

PADDLE_ENFORCE_EQ(
num_heads,
num_heads_2,
phi::errors::InvalidArgument(
"num heads of Tensor out and softmax_lse is not matched: [%d] vs [%d]",
num_heads,
num_heads_2));

PADDLE_ENFORCE_EQ(
seq_len_q,
seq_len_q_2,
phi::errors::InvalidArgument(
"seq_len_q of Tensor out and softmax_lse is not matched: [%d] vs [%d]",
seq_len_q,
seq_len_q_2));

TensorDistAttr seed_offset_dist_attr = seed_offset.dist_attr();
auto seed_offset_shape = common::vectorize(seed_offset.dims());

TensorDistAttr softmax_dist_attr = softmax.dist_attr();
auto softmax_shape = common::vectorize(softmax.dims());

std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
int used_axes_index = 0;
char batch_axis = alphabet[used_axes_index++];
char seq_len_q_axis = alphabet[used_axes_index++];
char num_heads_axis = alphabet[used_axes_index++];
char head_dim_axis = alphabet[used_axes_index++];
char seq_len_kv_axis = alphabet[used_axes_index++];
char head_dim_v_axis = alphabet[used_axes_index++];
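// One axis letter per logical tensor dim; tensors that share a letter must
// end up with the same sharding for that dim after the merge below.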

// [batch_size, seq_len_q, num_heads, head_dim]
std::string q_axes = {
batch_axis, seq_len_q_axis, num_heads_axis, head_dim_axis};
// [batch_size, seq_len_kv, num_heads, head_dim]
std::string k_axes = {
batch_axis, seq_len_kv_axis, num_heads_axis, head_dim_axis};
// [batch_size, seq_len_kv, num_heads, head_dim_v]
std::string v_axes = {
batch_axis, seq_len_kv_axis, num_heads_axis, head_dim_v_axis};
// [batch_size, seq_len_q, num_heads, head_dim_v]
std::string out_axes = {
batch_axis, seq_len_q_axis, num_heads_axis, head_dim_v_axis};
// [batch_size, num_heads, seq_len_q, seq_len_kv]
std::string softmax_axes = {
batch_axis, num_heads_axis, seq_len_q_axis, seq_len_kv_axis};
// [batch_size, num_heads, seq_len_q]
std::string softmax_lse_axes = {batch_axis, num_heads_axis, seq_len_q_axis};

// Sharding along seq_len_q and head_dim is not supported by this rule, so
// reset those dims of out (1 and 3) and softmax_lse (2) to replicated
// before merging.
auto out_dist_attr_dst = UnShardTensorDims(out_dist_attr, {1, 3});
auto softmax_lse_dist_attr_dst =
UnShardTensorDims(softmax_lse_dist_attr, {2});

std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;

axes_sharding_info.emplace_back(out_axes, out_dist_attr_dst.dims_mapping());
axes_sharding_info.emplace_back(softmax_lse_axes,
softmax_lse_dist_attr_dst.dims_mapping());

// Merge the out / softmax_lse shardings into a single axis -> mesh-dim map,
// then project it back onto each tensor below via MapDims.
auto axis_to_dim_map = ShardingMergeForTensors(axes_sharding_info);

auto q_dist_attr_dst = MapDims(q_dist_attr, axis_to_dim_map, q_axes);
auto k_dist_attr_dst = MapDims(k_dist_attr, axis_to_dim_map, k_axes);
auto v_dist_attr_dst = MapDims(v_dist_attr, axis_to_dim_map, v_axes);
out_dist_attr_dst = MapDims(out_dist_attr_dst, axis_to_dim_map, out_axes);
softmax_lse_dist_attr_dst =
MapDims(softmax_lse_dist_attr_dst, axis_to_dim_map, softmax_lse_axes);

// TODO(liuzhenhai): process fixed_seed_offset and attn_mask

auto fixed_seed_offset_dist_attr_dst = fixed_seed_offset_dist_attr;
auto attn_mask_dist_attr_dst = attn_mask_dist_attr;
auto softmax_dist_attr_dst = softmax_dist_attr;
auto seed_offset_dist_attr_dst = seed_offset_dist_attr;

VLOG(4) << "FlashAttInferSpmdReverse:";
VLOG(4) << "Einsum Notation: " << q_axes << "," << k_axes << "," << v_axes
<< "-->" << out_axes << "," << softmax_axes << ","
<< softmax_lse_axes;

LOG_SPMD_INPUT(q);
LOG_SPMD_INPUT(k);
LOG_SPMD_INPUT(v);
LOG_SPMD_INPUT(fixed_seed_offset);
LOG_SPMD_INPUT(attn_mask);

VLOG(4) << "Outputs:";
LOG_SPMD_INPUT(out);
LOG_SPMD_INPUT(softmax);
LOG_SPMD_INPUT(softmax_lse);
LOG_SPMD_INPUT(seed_offset);
VLOG(4) << std::endl;

return {{q_dist_attr_dst,
k_dist_attr_dst,
v_dist_attr_dst,
fixed_seed_offset_dist_attr_dst,
attn_mask_dist_attr_dst},
{out_dist_attr_dst,
softmax_dist_attr_dst,
softmax_lse_dist_attr_dst,
seed_offset_dist_attr_dst}};
}

SpmdInfo FlashAttGradInferSpmd(const DistMetaTensor& q,
const DistMetaTensor& k,
const DistMetaTensor& v,
24 changes: 24 additions & 0 deletions paddle/phi/infermeta/spmd_rules/flash_attention.h
@@ -33,6 +33,30 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q,
bool is_test = false,
const std::string& rng_name = "");

SpmdInfo FlashAttInferSpmdStatic(const DistMetaTensor& q,
const DistMetaTensor& k,
const DistMetaTensor& v,
const DistMetaTensor& fixed_seed_offset,
const DistMetaTensor& attn_mask,
float dropout,
bool causal,
bool return_softmax,
bool is_test);

SpmdInfo FlashAttInferSpmdReverse(const DistMetaTensor& q,
const DistMetaTensor& k,
const DistMetaTensor& v,
const DistMetaTensor& fixed_seed_offset,
const DistMetaTensor& attn_mask,
const DistMetaTensor& out,
const DistMetaTensor& softmax,
const DistMetaTensor& softmax_lse,
const DistMetaTensor& seed_offset,
float dropout,
bool causal,
bool return_softmax,
bool is_test);

SpmdInfo FlashAttGradInferSpmd(const DistMetaTensor& q,
const DistMetaTensor& k,
const DistMetaTensor& v,
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/spmd_rules/rules.h
@@ -516,6 +516,11 @@ PD_REGISTER_SPMD_RULE(
PD_INFER_SPMD(phi::distributed::LayerNormInferSpmd),
PD_INFER_SPMD(phi::distributed::LayerNormInferSpmdReverse));

// flash_attention rule
PD_REGISTER_SPMD_RULE(
flash_attention,
PD_INFER_SPMD(phi::distributed::FlashAttInferSpmdStatic),
PD_INFER_SPMD(phi::distributed::FlashAttInferSpmdReverse));

// reshape rule
PD_REGISTER_SPMD_RULE(reshape,
PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd),
1 change: 1 addition & 0 deletions test/auto_parallel/spmd_rules/CMakeLists.txt
@@ -25,6 +25,7 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_concat_rule MODULES test_concat_rule)
py_test_modules(test_where_rule MODULES test_where_rule)
py_test_modules(test_triu_rule MODULES test_triu_rule)
py_test_modules(test_flash_attention_rule MODULES test_flash_attention_rule)
# End of unittests WITH single card WITHOUT timeout

endif()
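
For context, below is a minimal sketch of the kind of check the new test_flash_attention_rule module could contain, following the pattern of the other spmd-rule tests in this directory. The mesh, tensor shapes, dims_mapping values, and the empty specs used for the optional fixed_seed_offset / attn_mask inputs are illustrative assumptions, not the contents of the actual test.

import unittest

from paddle.distributed.auto_parallel.static.dist_attribute import (
    DistTensorSpec,
    TensorDistAttr,
)
from paddle.distributed.fleet import auto
from paddle.framework import core


class TestFlashAttentionSPMDRule(unittest.TestCase):
    def test_batch_sharded_inputs(self):
        rule = core.get_phi_spmd_rule("flash_attention")
        mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]])

        def spec(shape, dims_mapping):
            # Build a DistTensorSpec with the given shape and dims_mapping.
            attr = TensorDistAttr()
            attr.process_mesh = mesh
            attr.dims_mapping = dims_mapping
            return DistTensorSpec(shape, attr)

        # q/k/v: [batch_size, seq_len, num_heads, head_dim], batch sharded on
        # mesh dim 0; seq_len and head_dim sharding is not supported by the rule.
        q = spec([2, 512, 8, 64], [0, -1, -1, -1])
        k = spec([2, 512, 8, 64], [0, -1, -1, -1])
        v = spec([2, 512, 8, 64], [0, -1, -1, -1])
        fixed_seed_offset = spec([], [])  # assumed empty optional input
        attn_mask = spec([], [])  # assumed empty optional input

        # attributes: dropout, causal, return_softmax, is_test
        inferred = rule.infer_forward(
            q, k, v, fixed_seed_offset, attn_mask, 0.0, True, False, False
        )
        input_dist_attrs, output_dist_attrs = inferred

        # out keeps the batch sharding: [batch, seq_len_q, num_heads, head_dim_v]
        self.assertEqual(output_dist_attrs[0].dims_mapping, [0, -1, -1, -1])
        # softmax_lse: [batch, num_heads, seq_len_q]
        self.assertEqual(output_dist_attrs[2].dims_mapping, [0, -1, -1])


if __name__ == "__main__":
    unittest.main()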