1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -145,6 +145,7 @@
    'fused_elemwise_add_activation',
    'fused_scale_bias_relu_conv_bn',
    'fused_scale_bias_add_relu',
    'fused_token_prune',
    'fused_dconv_drelu_dbn',
    'fused_dot_product_attention',
    'nce',
8 changes: 8 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -775,6 +775,14 @@
    func : fused_softmax_mask_upper_triangle
  backward: fused_softmax_mask_upper_triangle_grad

- op : fused_token_prune
  args : (Tensor attn, Tensor x, Tensor mask, Tensor new_mask, bool keep_first_token = true, bool keep_order = false)
  output : Tensor(slimmed_x), Tensor(cls_inds)
  infer_meta :
    func : FusedTokenPruneInferMeta
  kernel :
    func : fused_token_prune

- op : gaussian
  args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
  output: Tensor(out)
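For reference, the yaml entry above implies a Python-side call along these lines. A minimal sketch, assuming the op generator exposes a paddle._C_ops.fused_token_prune binding with the argument order from args; the binding name and its availability are assumptions, not shown in this diff:

import paddle

# Shapes follow the contract enforced by FusedTokenPruneInferMeta below:
#   attn, mask: [bsz, nb_head, max_seq_len, max_seq_len]
#   x:          [bsz, max_seq_len, c]
#   new_mask:   [bsz, nb_head, slim_seq_len, slim_seq_len]
bsz, nb_head, max_seq_len, slim_seq_len, c = 2, 4, 16, 8, 32

attn = paddle.rand([bsz, nb_head, max_seq_len, max_seq_len])
x = paddle.rand([bsz, max_seq_len, c])
mask = paddle.rand([bsz, nb_head, max_seq_len, max_seq_len])
new_mask = paddle.rand([bsz, nb_head, slim_seq_len, slim_seq_len])

# Hypothetical binding; name and module are assumptions.
slimmed_x, cls_inds = paddle._C_ops.fused_token_prune(
    attn, x, mask, new_mask, True, False)
# Expected: slimmed_x [bsz, slim_seq_len, c], cls_inds [bsz, slim_seq_len] (int64)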
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -44,6 +44,7 @@ const std::unordered_set<std::string> LegacyOpList = {
    FtrlOp::name(),
    FusedElemwiseAddActivationOp::name(),
    FusedElemwiseAddActivationGradOp::name(),
    FusedTokenPruneOp::name(),
    DpsgdOp::name(),
    SendV2Op::name(),
    RecvV2Op::name(),
6 changes: 6 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -3702,6 +3702,12 @@
  outputs :
    {out : Out}

- op: fused_token_prune
  inputs :
    {attn: Attn, x: X, mask: Mask, new_mask: NewMask}
  outputs :
    {slimmed_x : SlimmedX, cls_inds : CLSInds}

- op: fusion_squared_mat_sub
  inputs :
    x : X
80 changes: 80 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -4584,6 +4584,86 @@ void FusedRopeInferMeta(const MetaTensor& q,
  }
}

void FusedTokenPruneInferMeta(const MetaTensor& attn,
                              const MetaTensor& x,
                              const MetaTensor& mask,
                              const MetaTensor& new_mask,
                              bool keep_first_token,
                              bool keep_order,
                              MetaTensor* slimmed_x,
                              MetaTensor* cls_inds) {
  auto mask_dim = mask.dims();
  auto attn_dim = attn.dims();
  auto x_dim = x.dims();
  auto new_mask_dim = new_mask.dims();

  PADDLE_ENFORCE_EQ(
      mask_dim.size(),
      4,
      phi::errors::InvalidArgument("The input mask must be 4-dimensional"));
  PADDLE_ENFORCE_EQ(
      attn_dim.size(),
      4,
      phi::errors::InvalidArgument("The input attn must be 4-dimensional"));
  PADDLE_ENFORCE_EQ(
      x_dim.size(),
      3,
      phi::errors::InvalidArgument("The input x must be 3-dimensional"));
  PADDLE_ENFORCE_EQ(
      new_mask_dim.size(),
      4,
      phi::errors::InvalidArgument("The input new_mask must be 4-dimensional"));
  PADDLE_ENFORCE_EQ(mask_dim[0],
                    attn_dim[0],
                    phi::errors::InvalidArgument(
                        "The first dim of mask and attn should be the same, "
                        "which is batch size"));
  PADDLE_ENFORCE_EQ(mask_dim[1],
                    attn_dim[1],
                    phi::errors::InvalidArgument(
                        "The second dim of mask and attn should be the same, "
                        "which is nb_head"));
  PADDLE_ENFORCE_EQ(mask_dim[0],
                    x_dim[0],
                    phi::errors::InvalidArgument(
                        "The first dim of mask and x should be the same, "
                        "which is batch size"));
  PADDLE_ENFORCE_EQ(
      mask_dim[2],
      mask_dim[3],
      phi::errors::InvalidArgument(
          "The third dim and the fourth dim of mask should be the same, "
          "which is max seq len"));
  PADDLE_ENFORCE_EQ(
      attn_dim[2],
      attn_dim[3],
      phi::errors::InvalidArgument(
          "The third dim and the fourth dim of attn should be the same, "
          "which is max seq len"));
  PADDLE_ENFORCE_EQ(attn_dim[2],
                    mask_dim[2],
                    phi::errors::InvalidArgument(
                        "The third dim of mask and attn should be the same, "
                        "which is max seq len"));
  PADDLE_ENFORCE_EQ(attn_dim[2],
                    x_dim[1],
                    phi::errors::InvalidArgument(
                        "The third dim of attn and the second dim of x "
                        "should be the same, which is max seq len"));

  auto bsz = mask_dim[0];
  auto c = x_dim[2];
  auto slim_seq_len = new_mask_dim[2];

  std::vector<int64_t> slimmed_x_dims({bsz, slim_seq_len, c});
  slimmed_x->set_dims(common::make_ddim(slimmed_x_dims));
  slimmed_x->set_dtype(x.dtype());

  std::vector<int64_t> cls_inds_dims({bsz, slim_seq_len});
  cls_inds->set_dims(common::make_ddim(cls_inds_dims));
  cls_inds->set_dtype(phi::DataType::INT64);
}

void MoeInferMeta(const MetaTensor& x,
                  const MetaTensor& gate,
                  const MetaTensor& bmm0,
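The checks above encode a simple shape contract. Here is a plain-Python mirror of FusedTokenPruneInferMeta for reference; the function name is illustrative, not part of the PR:

def fused_token_prune_infer_shapes(attn, x, mask, new_mask):
    # attn/mask/new_mask are 4-D shapes, x is a 3-D shape (lists of ints).
    assert len(attn) == len(mask) == len(new_mask) == 4 and len(x) == 3
    assert mask[0] == attn[0] == x[0]                        # batch size
    assert mask[1] == attn[1]                                # nb_head
    assert mask[2] == mask[3] == attn[2] == attn[3] == x[1]  # max seq len
    bsz, c, slim_seq_len = mask[0], x[2], new_mask[2]
    # slimmed_x shape, cls_inds shape (cls_inds dtype is int64)
    return [bsz, slim_seq_len, c], [bsz, slim_seq_len]

For example, fused_token_prune_infer_shapes([2, 4, 16, 16], [2, 16, 32], [2, 4, 16, 16], [2, 4, 8, 8]) returns ([2, 8, 32], [2, 8]).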
9 changes: 9 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -908,6 +908,15 @@ void FusedRopeInferMeta(const MetaTensor& q,
                        MetaTensor* out_k,
                        MetaTensor* out_v);

void FusedTokenPruneInferMeta(const MetaTensor& attn,
                              const MetaTensor& x,
                              const MetaTensor& mask,
                              const MetaTensor& new_mask,
                              bool keep_first_token,
                              bool keep_order,
                              MetaTensor* slimmed_x,
                              MetaTensor* cls_inds);

void MultiheadMatmulInferMeta(const MetaTensor& input,
                              const MetaTensor& w,
                              const MetaTensor& bias,
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
@@ -116,6 +116,7 @@ test_fused_fc_elementwise_layernorm_op
test_fused_feedforward_op
test_fused_gate_attention_op
test_fused_multihead_matmul_op
test_fused_token_prune_op
test_fusion_seqexpand_concat_fc_op
test_fusion_transpose_flatten_concat_op
test_gather_nd_op
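For context, the whitelisted test_fused_token_prune_op exercises semantics along the lines of the NumPy reference below: each token is scored by the attention mass it receives, and the sequence is slimmed to new_mask's length. This is a sketch of the op's commonly described behavior, not code from this PR; the exact scoring reduction and tie-breaking are assumptions:

import numpy as np

def fused_token_prune_ref(attn, x, slim_seq_len,
                          keep_first_token=True, keep_order=False):
    # attn: [bsz, nb_head, seq, seq]; x: [bsz, seq, c]
    bsz = attn.shape[0]
    scores = attn.sum(axis=(1, 2))            # attention received per token
    slimmed = np.empty((bsz, slim_seq_len, x.shape[2]), dtype=x.dtype)
    inds = np.empty((bsz, slim_seq_len), dtype=np.int64)
    for b in range(bsz):
        order = np.argsort(-scores[b])        # tokens by descending score
        if keep_first_token:                  # always retain the CLS token
            order = np.concatenate(([0], order[order != 0]))
        kept = order[:slim_seq_len]
        if keep_order:                        # restore original positions
            kept = np.sort(kept)
        inds[b] = kept
        slimmed[b] = x[b, kept]
    return slimmed, inds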