@@ -50,7 +50,6 @@
     "squeeze",
     "stack",
     "unsqueeze",
-    "tile",
 ]

 # come into effect in generated file op_decomp.cc
@@ -84,7 +83,6 @@
     "squeeze",
     "stack",
     "unsqueeze",
-    "tile",
 ]

86 changes: 0 additions & 86 deletions paddle/fluid/primitive/composite/composite.h
@@ -990,92 +990,6 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
   return std::make_tuple(out, mean_out, var_out);
 }

-template <typename T>
-Tensor tile_decomp(const Tensor& x, const IntArray& repeat_times) {
-  // Worked example: x.shape = [3, 4], repeat_times = (a, b, c)
-  // shape1 = [1, 3, 4]
-  // shape2 = [1, 1, 1, 3, 1, 4]
-  // shape3 = [a, 1, b, 3, c, 4]
-  // final_shape = [a * 1, b * 3, c * 4]
-  // t1 = x.reshape(shape1)
-  // t2 = t1.reshape(shape2)
-  // t3 = t2.expand(shape3)
-  // res = t3.reshape(final_shape)
-  std::vector<int64_t> repeat_times_ = repeat_times.GetData();
-  std::vector<int64_t> shape1 = x.shape();
-  auto diff = int64_t(repeat_times_.size()) - int64_t(shape1.size());
-  Tensor t1;
-  if (has_dynamic_shape(shape1)) {
-    size_t repeat_time_length = repeat_times_.size();
-    std::vector<int64_t> unsqueeze_idx2;
-    if (diff > 0) {
-      std::vector<int64_t> unsqueeze_idx1(diff);
-      std::iota(unsqueeze_idx1.begin(), unsqueeze_idx1.end(), 0);
-      t1 = unsqueeze<T>(x, unsqueeze_idx1);
-    } else {
-      t1 = x;
-    }
-    auto length2 = t1.dims().size();
-    for (size_t i = 0; i < repeat_times_.size(); i++) {
-      unsqueeze_idx2.push_back(length2 - repeat_times_.size() + i * 2);
-    }
-
-    Tensor t2 = unsqueeze<T>(t1, unsqueeze_idx2);
-    std::vector<int64_t> ref_shape(t2.dims().size(), 1);
-    for (size_t i = 0; i < unsqueeze_idx2.size(); i++) {
-      ref_shape[unsqueeze_idx2[i]] = repeat_times_[i];
-    }
-    Tensor ref_t = full<T>(ref_shape, 1.0, t2.dtype());
-    Tensor t3 = t2 * ref_t;
-    Tensor origin_shape_t = shape<T>(t1);
-    std::vector<Tensor> res_s;
-    for (int64_t i = int64_t(length2) - 1; i >= 0; i--) {
-      auto relative_idx =
-          int64_t(repeat_time_length) - 1 - int64_t(length2 - i - 1);
-
-      if (relative_idx >= 0) {
-        res_s.insert(
-            res_s.begin(),
-            get_slice<T>(origin_shape_t, i) * repeat_times_[relative_idx]);
-      } else {
-        res_s.insert(res_s.begin(), get_slice<T>(origin_shape_t, i));
-      }
-    }
-    Tensor s4 = concat<T>(res_s, 0);
-    return backend::reshape_with_tensor<T>(t3, s4);
-
-  } else {
-    if (diff > 0) {
-      for (int64_t i = 0; i < diff; i++) {
-        shape1.insert(shape1.begin(), 1);
-      }
-    }
-
-    auto length = int64_t(shape1.size());
-    std::vector<int64_t> shape2 = shape1;
-    std::vector<int64_t> shape3 = shape1;
-    std::vector<int64_t> final_shape = shape1;
-    auto r_length = repeat_times_.size();
-    for (size_t j = 0; j < repeat_times_.size(); j++) {
-      int64_t i = int64_t(j);
-
-      shape2.insert(shape2.begin() + (length - 1 - i), 1);
-      shape3.insert(shape3.begin() + (length - 1 - i),
-                    repeat_times_[r_length - i - 1]);
-
-      final_shape[length - i - 1] =
-          final_shape[length - i - 1] * repeat_times_[r_length - i - 1];
-    }
-
-    t1 = reshape<T>(x, shape1);
-
-    auto t2 = reshape<T>(t1, shape2);
-    auto t3 = t2.expand(shape3);
-    auto res = reshape<T>(t3, final_shape);
-    return res;
-  }
-}
-
 template <typename T>
 Tensor square_decomp(const Tensor& x) {
   auto org_dtype = x.dtype();
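For reference, the static-shape branch of the removed tile_decomp can be checked outside Paddle. Below is a minimal NumPy sketch — illustrative only: np.broadcast_to stands in for the expand<T> primitive, and the concrete factors a, b, c = 2, 5, 6 are assumed — verifying the reshape → reshape → expand → reshape rewrite against np.tile:

# Minimal NumPy sketch of the removed static-shape tile_decomp path.
# The shapes follow the worked example in the deleted comment.
import numpy as np

x = np.arange(12).reshape(3, 4)      # x.shape = [3, 4]
repeat_times = (2, 5, 6)             # (a, b, c)

shape1 = [1, 3, 4]                   # rank padded with leading 1s
shape2 = [1, 1, 1, 3, 1, 4]          # a 1 inserted in front of every dim
shape3 = [2, 1, 5, 3, 6, 4]          # the inserted 1s become repeat factors
final_shape = [2, 15, 24]            # [2 * 1, 5 * 3, 6 * 4]

t1 = x.reshape(shape1)
t2 = t1.reshape(shape2)
t3 = np.broadcast_to(t2, shape3)     # "expand" via broadcasting
res = t3.reshape(final_shape)

assert np.array_equal(res, np.tile(x, repeat_times))

The dynamic-shape branch reaches the same result without static shapes: multiplying by an all-ones tensor of ref_shape lets broadcasting perform the expansion, and reshape_with_tensor consumes a result shape assembled at runtime from get_slice and concat.
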
36 changes: 0 additions & 36 deletions test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
@@ -70,16 +70,6 @@ def stack_net(x):
     return paddle.stack([x, y], axis=0)


-def tile_net1(x):
-    y = paddle.tile(x, repeat_times=[2, 5])
-    return y
-
-
-def tile_net2(x):
-    y = paddle.tile(x, repeat_times=[3, 2, 5])
-    return y
-
-
 def index_sample_net(x, index):
     return paddle.index_sample(x, index)

@@ -251,32 +241,6 @@ def setUp(self):
         self.tol = 1e-6


-class TestPrimTile(TestPrimBase):
-    def setUp(self):
-        np.random.seed(2023)
-        self.dtype = "float32"
-        self.x_shape = [1, 300, 4096]
-        self.init_x_shape = [None, None, 4096]
-        self.x = np.random.random(self.x_shape).astype(self.dtype)
-        self.net = tile_net1
-        self.necessary_ops = "pd_op.tile"
-        self.enable_cinn = False
-        self.tol = 1e-6
-
-
-class TestPrimTile2(TestPrimBase):
-    def setUp(self):
-        np.random.seed(2023)
-        self.dtype = "float32"
-        self.x_shape = [300, 4096]
-        self.init_x_shape = [None, 4096]
-        self.x = np.random.random(self.x_shape).astype(self.dtype)
-        self.net = tile_net2
-        self.necessary_ops = "pd_op.tile"
-        self.enable_cinn = False
-        self.tol = 1e-6
-
-
 class TestPrimTwo(unittest.TestCase):
     def setUp(self):
         np.random.seed(2023)
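The two deleted tests covered both rank cases of the decomposition: tile_net1 passes repeat_times shorter than the input rank, while tile_net2 passes repeat_times longer than it (the diff > 0 path). A short sketch of the expected shapes, assuming an installed Paddle build:

# Sketch of the shape behavior the removed tests exercised; shapes
# mirror TestPrimTile / TestPrimTile2 above.
import paddle

x1 = paddle.rand([1, 300, 4096])
# repeat_times shorter than x's rank is treated as [1, 2, 5]
print(paddle.tile(x1, repeat_times=[2, 5]).shape)     # [1, 600, 20480]

x2 = paddle.rand([300, 4096])
# repeat_times longer than x's rank promotes x to [1, 300, 4096]
print(paddle.tile(x2, repeat_times=[3, 2, 5]).shape)  # [3, 600, 20480]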