Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@
'fused_elemwise_add_activation',
'fused_scale_bias_relu_conv_bn',
'fused_scale_bias_add_relu',
'fused_token_prune',
'fused_dconv_drelu_dbn',
'fused_dot_product_attention',
'nce',
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,14 @@
func : fused_softmax_mask_upper_triangle
backward: fused_softmax_mask_upper_triangle_grad

- op : fused_token_prune
args : (Tensor attn, Tensor x, Tensor mask, Tensor new_mask, bool keep_first_token = true, bool keep_order = false)
output : Tensor(slimmed_x), Tensor(clsi_nds)
infer_meta :
func : FusedTokenPruneInferMeta
kernel:
func : fused_token_prune

- op : gaussian
args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
output: Tensor(out)
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ const std::unordered_set<std::string> LegacyOpList = {
FtrlOp::name(),
FusedElemwiseAddActivationOp::name(),
FusedElemwiseAddActivationGradOp::name(),
FusedTokenPruneOp::name(),
DpsgdOp::name(),
SendV2Op::name(),
RecvV2Op::name(),
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3702,6 +3702,12 @@
outputs :
{out : Out}

- op: fused_token_prune
inputs :
{attn: Attn, x: X, mask: Mask, new_mask: NewMask}
outputs :
{slimmed_x : SlimmedX, clsi_nds : CLSInds}

- op: fusion_squared_mat_sub
inputs :
x : X
Expand Down
80 changes: 80 additions & 0 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4584,6 +4584,86 @@ void FusedRopeInferMeta(const MetaTensor& q,
}
}

void FusedTokenPruneInferMeta(const MetaTensor& attn,
const MetaTensor& x,
const MetaTensor& mask,
const MetaTensor& new_mask,
bool keep_first_token,
bool keep_order,
MetaTensor* slimmed_x,
MetaTensor* clsi_nds) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
MetaTensor* clsi_nds) {
MetaTensor* cls_inds) {

这里的没改过来,辛苦再修改下吧~

auto mask_dim = mask.dims();
auto attn_dim = attn.dims();
auto x_dim = x.dims();
auto new_mask_dim = new_mask.dims();

PADDLE_ENFORCE_EQ(
mask_dim.size(),
4,
phi::errors::InvalidArgument("The input mask must be 4-dimension"));
PADDLE_ENFORCE_EQ(
attn_dim.size(),
4,
phi::errors::InvalidArgument("The input attn must be 4-dimension"));
PADDLE_ENFORCE_EQ(
x_dim.size(),
3,
phi::errors::InvalidArgument("The input x must be 4-dimension"));
PADDLE_ENFORCE_EQ(
new_mask_dim.size(),
4,
phi::errors::InvalidArgument("The input attn must be 4-dimension"));
PADDLE_ENFORCE_EQ(mask_dim[0],
attn_dim[0],
phi::errors::InvalidArgument(
"The first dim of mask and attn should be the same"
"which is batch size"));
PADDLE_ENFORCE_EQ(mask_dim[1],
attn_dim[1],
phi::errors::InvalidArgument(
"The second dim of mask and attn should be the same"
"which is nb_head"));
PADDLE_ENFORCE_EQ(mask_dim[0],
x_dim[0],
phi::errors::InvalidArgument(
"The first dim of mask and x should be the same"
"which is batch size"));
PADDLE_ENFORCE_EQ(
mask_dim[2],
mask_dim[3],
phi::errors::InvalidArgument(
"The third dim and the fourth dim of mask should be the same"
"which is max seq len"));
PADDLE_ENFORCE_EQ(
attn_dim[2],
attn_dim[3],
phi::errors::InvalidArgument(
"The third dim and the fourth dim of mask should be the same"
"which is max seq len"));
PADDLE_ENFORCE_EQ(attn_dim[2],
mask_dim[2],
phi::errors::InvalidArgument(
"The third dim of mask and attn should be the same"
"which is max seq len"));
PADDLE_ENFORCE_EQ(attn_dim[2],
x_dim[1],
phi::errors::InvalidArgument(
"The third dim of mask and the second dim of attn"
"should be the same which is max seq len"));

auto bsz = mask_dim[0];
auto c = x_dim[2];
auto slim_seq_len = new_mask_dim[2];

std::vector<int64_t> slimmed_x_dims({bsz, slim_seq_len, c});
slimmed_x->set_dims(common::make_ddim(slimmed_x_dims));
slimmed_x->set_dtype(x.dtype());

std::vector<int64_t> clsi_nds_dims({bsz, slim_seq_len});
clsi_nds->set_dims(common::make_ddim(clsi_nds_dims));
clsi_nds->set_dtype(phi::DataType::INT64);
}

void MoeInferMeta(const MetaTensor& x,
const MetaTensor& gate,
const MetaTensor& bmm0,
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/infermeta/multiary.h
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,15 @@ void FusedRopeInferMeta(const MetaTensor& q,
MetaTensor* out_k,
MetaTensor* out_v);

// Shape inference for fused_token_prune; see multiary.cc for the
// enforced input ranks and output shapes.
void FusedTokenPruneInferMeta(const MetaTensor& attn,
                              const MetaTensor& x,
                              const MetaTensor& mask,
                              const MetaTensor& new_mask,
                              bool keep_first_token,
                              bool keep_order,
                              MetaTensor* slimmed_x,
                              MetaTensor* cls_inds);


void MultiheadMatmulInferMeta(const MetaTensor& input,
const MetaTensor& w,
const MetaTensor& bias,
Expand Down
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ test_fused_fc_elementwise_layernorm_op
test_fused_feedforward_op
test_fused_gate_attention_op
test_fused_multihead_matmul_op
test_fused_token_prune_op
test_fusion_seqexpand_concat_fc_op
test_fusion_transpose_flatten_concat_op
test_gather_nd_op
Expand Down