1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -183,6 +183,7 @@
'prune_gate_by_capacity',
'push_sparse_v2',
'push_sparse_v2_',
'partial_concat',
'partial_send',
'partial_recv',
'partial_allgather',
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -1213,6 +1213,16 @@
func : partial_allgather
inplace : (x -> out)

- op : partial_concat
args : (Tensor[] x, int start_index = 0, int length = -1)
output : Tensor(out)
infer_meta :
func : PartialConcatInferMeta
kernel :
func : partial_concat
data_type : x
backward : partial_concat_grad

- op : partial_recv
args : (int ring_id = 0, int peer = 0, DataType dtype=DataType::FLOAT32, int[] out_shape= {}, bool use_calc_stream = false, int num = 1, int id = 0)
output : Tensor(out)
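For reference, a minimal NumPy sketch of the forward behaviour implied by the partial_concat entry above and by PartialConcatInferMeta below: every 2-D input contributes the column slice [start_index, start_index + length) (the whole tail when length is -1), and the slices are concatenated along axis 1. The helper name partial_concat_ref is made up for illustration; this is not the PHI kernel.

```python
import numpy as np

def partial_concat_ref(xs, start_index=0, length=-1):
    # All inputs must be 2-D and share the same shape (see PartialConcatInferMeta).
    input_len = xs[0].shape[1]
    if start_index < 0:
        start_index += input_len  # normalize a negative start_index
    end = input_len if length < 0 else start_index + length
    # The same column slice is taken from every input and concatenated along axis 1.
    return np.concatenate([x[:, start_index:end] for x in xs], axis=1)

a = np.arange(6, dtype=np.float32).reshape(2, 3)
b = np.arange(6, 12, dtype=np.float32).reshape(2, 3)
print(partial_concat_ref([a, b], start_index=1, length=2).shape)  # (2, 4)
```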
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
@@ -580,6 +580,16 @@
composite : pad_grad(x, out_grad, paddings, pad_value, x_grad)
backward : pad_double_grad

- backward_op : partial_concat_grad
forward : partial_concat (Tensor[] x, int start_index = 0, int length = -1) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad, int start_index, int length)
output : Tensor[](x_grad){x.size()}
infer_meta :
func : PartialConcatGradInferMeta
param : [x]
kernel :
func : partial_concat_grad

- backward_op : partial_sum_grad
forward : partial_sum (Tensor[] x, int start_index = 0, int length = -1) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad, int start_index, int length)
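Likewise, a rough NumPy sketch of the gradient the partial_concat_grad entry above is expected to produce under the same 2-D assumptions: each x_grad keeps its input's shape, is zero outside the selected columns, and receives the i-th column block of out_grad inside them. The name partial_concat_grad_ref is hypothetical.

```python
import numpy as np

def partial_concat_grad_ref(xs, out_grad, start_index=0, length=-1):
    input_len = xs[0].shape[1]
    if start_index < 0:
        start_index += input_len
    slice_len = input_len - start_index if length < 0 else length
    x_grads = [np.zeros_like(x) for x in xs]
    for i, g in enumerate(x_grads):
        # Column block i of out_grad flows back into columns
        # [start_index, start_index + slice_len) of input i; the rest stays zero.
        g[:, start_index:start_index + slice_len] = \
            out_grad[:, i * slice_len:(i + 1) * slice_len]
    return x_grads
```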
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -73,6 +73,8 @@ const std::unordered_set<std::string> LegacyOpList = {
SoftReluGradOp::name(),
MatchMatrixTensorOp::name(),
MatchMatrixTensorGradOp::name(),
PartialConcatOp::name(),
PartialConcatGradOp::name(),
NceOp::name(),
NceGradOp::name(),
PartialSumOp::name(),
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -2487,6 +2487,15 @@
outputs :
out : Out

- op : partial_concat
backward : partial_concat_grad
inputs :
x : X
outputs :
out : Out
extra :
attrs : [bool use_mkldnn = false]

- op : partial_recv
outputs :
out : Out
10 changes: 10 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -850,6 +850,16 @@ void NanmedianGradInferMeta(const MetaTensor& x,
x_grad->set_dtype(x.dtype());
}

void PartialConcatGradInferMeta(const std::vector<const MetaTensor*>& xs,
std::vector<MetaTensor*> x_grads) {
auto input_num = xs.size();
for (size_t i = 0; i < input_num; i++) {
auto x_dims = xs[i]->dims();
x_grads[i]->set_dims(x_dims);
x_grads[i]->set_dtype(xs[i]->dtype());
}
}

void NceGradInferMeta(const MetaTensor& input,
const MetaTensor& bias,
const MetaTensor& weight,
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -373,6 +373,9 @@ void NanmedianGradInferMeta(const MetaTensor& x,
const std::string& mode,
MetaTensor* x_grad);

void PartialConcatGradInferMeta(const std::vector<const MetaTensor*>& xs,
std::vector<MetaTensor*> x_grads);

void PartialSumGradInferMeta(const std::vector<const MetaTensor*>& xs,
std::vector<MetaTensor*> x_grads);

71 changes: 71 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -4543,6 +4543,77 @@ void PartialSumInferMeta(const std::vector<const MetaTensor*>& xs,
out->set_dtype(xs[0]->dtype());
}

void PartialConcatInferMeta(const std::vector<const MetaTensor*>& xs,
int start_index,
int length,
MetaTensor* out,
MetaConfig config) {
int64_t batch_size = -1;
int64_t input_len = -1;

auto inputs_num = xs.size();
PADDLE_ENFORCE_GT(inputs_num,
0,
phi::errors::InvalidArgument(
"ShapeError: Input tensors count should > 0. But "
"received inputs' length is 0."));

// Only two-dimensional inputs are supported for now; extend later if needed.
// When length is -1, all dimensions being concatenated must be the same.
for (size_t i = 0; i < inputs_num; i++) {
auto x_dim = xs[i]->dims();

PADDLE_ENFORCE_EQ(
x_dim.size(),
2,
phi::errors::InvalidArgument("Only support two dimensions input now."));

if (i == 0) {
batch_size = x_dim[0];
input_len = x_dim[1];
} else {
// every tensor's dims must match
PADDLE_ENFORCE_EQ(x_dim[0],
batch_size,
phi::errors::InvalidArgument(
"The batch size of all inputs must be the same"));
PADDLE_ENFORCE_EQ(x_dim[1],
input_len,
phi::errors::InvalidArgument(
"The input length of all inputs must be the same"));
}
}

PADDLE_ENFORCE_EQ(
start_index >= -input_len && start_index < input_len,
true,
phi::errors::InvalidArgument(
"The start_index is expected to be in range of [%d, %d), but got %d",
-input_len,
input_len,
start_index));

if (start_index < 0) {
start_index += input_len;
}

if (length > 0) {
PADDLE_ENFORCE_GE(input_len,
start_index + length,
phi::errors::OutOfRange(
"start_index + length is larger than input length"));
}
Comment on lines +4600 to +4605

Contributor:

Why is this extra check added?

Contributor Author:

> Why is this extra check added?

I checked, and it is indeed unnecessary; the ENFORCE above already covers this case. Fixed.

Contributor:

Which conditions cover start_index + length <= input_len? Frankly, I don't see it. This part can actually be merged as-is, it just needs an added explanation.

Contributor:

Also, I will merge the PR directly once CI passes, so don't change it for now. Add some explanation for the change here, and if necessary submit a follow-up PR.

Contributor Author:

> start_index

I originally thought the start_index >= -input_len && start_index < input_len check above was the out-of-bounds check for start_index, and that, substituted into the start_index + length bounds check below, it already covered this case. I just realized that start_index + length is effectively end_index, so it does need its own check.

Contributor Author:

> Which conditions cover start_index + length <= input_len? Frankly, I don't see it. This part can actually be merged as-is, it just needs an added explanation.

int start_index = static_cast<int>(ComputeStartIndex(

Thinking it over, this logic should in principle also include the check. For example, when partial_len is greater than 0, the size is inferred directly as partial_len * inputs_num, without considering the condition start_index + length <= input_len.

Contributor Author (@cmcamdy, Mar 19, 2024):

> this condition

That is, if it goes out of bounds the inferred size ends up too large. I also took a look at its compute definition, and it does not seem to consider this either, right here:

memcpy(out_data + out_size * j + partial_len * i,

If it goes out of bounds, the memcpy may copy data from the next row.


std::vector<int64_t> out_dims(2);
out_dims[0] = batch_size;
// output columns = inputs_num * length (or inputs_num * (input_len - start_index) when length < 0)
out_dims[1] = (length < 0) ? input_len - start_index : length;
out_dims[1] *= inputs_num;
DDim out_dim = common::make_ddim(out_dims);
out->set_dims(out_dim);
out->set_dtype(xs[0]->dtype());
}

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
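To make the review thread above concrete, here is a small Python sketch (an assumed restatement of PartialConcatInferMeta, not the actual kernel) of the output-shape rule together with the questioned start_index + length <= input_len check. With input_len = 4, start_index = 3 and length = 2, start_index alone is in range, yet the slice would run one column past the end of each row, which is the case the extra ENFORCE_GE (and, per the discussion, the kernel's memcpy) has to guard against.

```python
def partial_concat_out_shape(batch_size, input_len, inputs_num,
                             start_index=0, length=-1):
    # Mirrors PartialConcatInferMeta: validate start_index, normalize it,
    # optionally check the slice end, then out = [batch_size, slice_len * inputs_num].
    assert -input_len <= start_index < input_len, "start_index out of range"
    if start_index < 0:
        start_index += input_len
    if length > 0:
        # The check discussed in review: without it the slice would extend
        # past the end of each row whenever start_index + length > input_len.
        assert start_index + length <= input_len, "slice exceeds input length"
    slice_len = input_len - start_index if length < 0 else length
    return [batch_size, slice_len * inputs_num]

print(partial_concat_out_shape(8, 4, 3, start_index=1, length=2))  # [8, 6]
# partial_concat_out_shape(8, 4, 3, start_index=3, length=2) would trip the
# assert: start_index is valid, but start_index + length = 5 > input_len = 4.
```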
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -705,6 +705,12 @@ void SumRawInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

void PartialConcatInferMeta(const std::vector<const MetaTensor*>& xs,
int start_index,
int length,
MetaTensor* out,
MetaConfig config = MetaConfig());

void PartialSumInferMeta(const std::vector<const MetaTensor*>& xs,
int start_index,
int length,
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
@@ -201,6 +201,7 @@ test_one_hot_v2_op
test_one_hot_v2_op_static_build
test_overlap_add_op
test_pad3d_op
test_partial_concat_op
test_partial_sum_op
test_pass_quantization
test_pixel_shuffle_op