Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,11 @@ pir::OpInfo OpTranscriber::LookUpOpInfo(pir::IrContext* ctx,
std::map<std::string, std::vector<std::string>> inputs = op_desc.Inputs();
std::vector<std::string> input_types;
for (const auto& pair : inputs) {
if (op_desc.Type() == "sparse_sum" || op_desc.Type() == "sparse_slice") {
if (pair.first != "x") {
continue;
}
}
Comment on lines +308 to +312
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里下个PR可以改进下,通过表来配置,避免硬编码,最好复用现在已有的信息

VarDesc* var_desc = op_desc.Block()->FindVarRecursive(pair.second[0]);
PADDLE_ENFORCE_NE(
var_desc,
Expand Down
19 changes: 19 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3226,6 +3226,25 @@
attrs:
data_format: data_layout

- op : sparse_slice
int_array :
starts :
data_type : int
tensor_name : StartsTensor
tensors_name : StartsTensorList
ends :
data_type : int
tensor_name : EndsTensor
tensors_name : EndsTensorList

- op : sparse_sum
attrs:
axis: axis
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里删除,相同的不要配置

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

scalar :
axis :
data_type : int
tensor_name : AxisTensor

- op : sparse_sync_batch_norm
attrs:
data_format: data_layout
Expand Down
7 changes: 7 additions & 0 deletions test/deprecated/legacy_test/test_sparse_slice_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import unittest

import numpy as np
from utils import compare_legacy_with_pt

import paddle

Expand Down Expand Up @@ -206,26 +207,32 @@ def check_result_with_list(self, x, axes, starts, ends, format='coo'):
if format == 'coo':
self._check_result_coo(np_x, axes, starts, ends)

@compare_legacy_with_pt
def test_coo_5d(self):
for item in data_5d:
self.check_result_with_shape(*item, format='coo')

@compare_legacy_with_pt
def test_coo_4d(self):
for item in data_4d:
self.check_result_with_shape(*item, format='coo')

@compare_legacy_with_pt
def test_coo_3d(self):
for item in data_3d:
self.check_result_with_shape(*item, format='coo')

@compare_legacy_with_pt
def test_coo_2d(self):
for item in data_2d:
self.check_result_with_shape(*item, format='coo')

@compare_legacy_with_pt
def test_coo_1d(self):
x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0]
self.check_result_with_list(x, [0], [3], [5], format='coo')

@compare_legacy_with_pt
def test_coo_1d_zero(self):
x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0]
self.check_result_with_list(x, [0], [-3], [-1], format='coo')
Expand Down
2 changes: 2 additions & 0 deletions test/deprecated/legacy_test/test_sparse_sum_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import unittest

import numpy as np
from utils import compare_legacy_with_pt

import paddle

Expand Down Expand Up @@ -172,6 +173,7 @@ def check_result_coo(self, x_shape, dims, keepdim, dtype=None):
)
paddle.disable_static()

@compare_legacy_with_pt
def test_sum(self):
# 1d
self.check_result_coo([5], None, False)
Expand Down