diff --git a/paddle/cinn/hlir/pe/transform.cc b/paddle/cinn/hlir/pe/transform.cc index b91a509b7a1f5f..3cd4120f89a1b6 100644 --- a/paddle/cinn/hlir/pe/transform.cc +++ b/paddle/cinn/hlir/pe/transform.cc @@ -1070,18 +1070,25 @@ ir::Tensor SliceSymbolic(const ir::Tensor& A, input_shape.emplace_back(shape); } - std::vector new_starts(starts); + std::vector new_starts; + std::transform(starts.begin(), + starts.end(), + std::back_inserter(new_starts), + [](const int start) { return ir::Expr(start); }); + for (int i = 0; i < axes.size(); i++) { - CHECK(input_shape[axes[i]].is_constant()) - << "Not supported Slice in dynamic dimensions, because the " - "relationship between slice range and symbol size cannot be " - "determined at compile time"; - if (new_starts[i] < -input_shape[axes[i]].as_int64()) { - new_starts[i] = 0; - } else if (new_starts[i] < 0) { - new_starts[i] = input_shape[axes[i]].as_int64() + new_starts[i]; - } else if (new_starts[i] > input_shape[axes[i]].as_int64()) { - new_starts[i] = input_shape[axes[i]].as_int64() - 1; + if (input_shape[axes[i]].is_constant()) { + if (new_starts[i].as_int64() < -input_shape[axes[i]].as_int64()) { + new_starts[i] = ir::Expr(0); + } else if (new_starts[i].as_int64() < 0) { + new_starts[i] = input_shape[axes[i]].as_int64() + new_starts[i]; + } else if (new_starts[i].as_int64() > input_shape[axes[i]].as_int64()) { + new_starts[i] = input_shape[axes[i]].as_int64() - ir::Expr(1); + } + } else { + if (new_starts[i].as_int64() < 0) { + new_starts[i] = ir::Add::Make(input_shape[axes[i]], new_starts[i]); + } } } diff --git a/test/ir/pir/cinn/symbolic/test_infer_sym_shape_multinary_op.py b/test/ir/pir/cinn/symbolic/test_infer_sym_shape_multinary_op.py index 82272b4a0f59a9..597d64837fcb80 100644 --- a/test/ir/pir/cinn/symbolic/test_infer_sym_shape_multinary_op.py +++ b/test/ir/pir/cinn/symbolic/test_infer_sym_shape_multinary_op.py @@ -86,7 +86,7 @@ def test_eval_symbolic(self): ) input_spec = [x_spec] - net = apply_to_static(net, True, input_spec) + net = apply_to_static(net, False, input_spec) net.eval() check_infer_results(net, input_spec, 'pd_op.slice', self.expected)