Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
'add_n_',
'all_reduce',
'all_reduce_',
'assign_pos',
'batch_fc',
'barrier',
'c_allgather',
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@
inplace : (output -> out)
backward : assign_out__grad

- op : assign_pos
args : (Tensor x, Tensor cum_count, Tensor eff_num_len)
output : Tensor(out)
infer_meta :
func : AssignPosInferMeta
kernel :
func : assign_pos

- op : assign_value
args : (int[] shape, DataType dtype, Scalar[] values, Place place = {})
output : Tensor(out)
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,12 @@
get_expected_kernel_type :
assign : GetAssignExpectedKernelType

- op : assign_pos
inputs :
{x : X, cum_count : cum_count, eff_num_len : eff_num_len}
outputs :
out : Out

- op : assign_value
outputs :
out : Out
Expand Down
19 changes: 19 additions & 0 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,25 @@ void AddmmInferMeta(const MetaTensor& input,
out->set_dtype(input.dtype());
}

void AssignPosInferMeta(const MetaTensor& x,
const MetaTensor& cum_count,
const MetaTensor& eff_num_len,
MetaTensor* out) {
phi::DataType X_dtype = x.dtype();
phi::DataType cum_count_dtype = cum_count.dtype();

PADDLE_ENFORCE_EQ(cum_count_dtype,
X_dtype,
phi::errors::InvalidArgument(
"The dtype of the cum_count and X should be same"));
PADDLE_ENFORCE_EQ(cum_count_dtype,
phi::DataType::INT64,
phi::errors::InvalidArgument(
"The dtype of the cum_count_dtype, eff_num_len and "
"X should be same as int64"));
out->set_dtype(X_dtype);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不需要设置ddim吗?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

原本是空的,这里我觉得不改变现状比较好,后续有优化需求单独做

}

void BatchFCInferMeta(const MetaTensor& input,
const MetaTensor& w,
const MetaTensor& bias,
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ void ArangeTensorInferMeta(const MetaTensor& start,
const MetaTensor& step,
MetaTensor* out);

void AssignPosInferMeta(const MetaTensor& x,
const MetaTensor& cum_count,
const MetaTensor& eff_num_len,
MetaTensor* out);

void BatchFCInferMeta(const MetaTensor& input,
const MetaTensor& w,
const MetaTensor& bias,
Expand Down
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_no_check_list
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
test_assign_pos_op
test_bernoulli_op
test_dirichlet_op
test_empty_op
Expand Down
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ test_arg_min_max_op_static_build
test_arg_min_max_v2_op
test_argsort_op
test_assign_op
test_assign_pos_op
test_assign_value_op
test_atan2_op
test_auc_op
Expand Down