diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc
index bbb1ec69267ae8..982b9fddc59fa3 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc
@@ -508,6 +508,14 @@ bool FullOpInferSymbolicShape(pir::Operation *op,
     sym_shape.push_back(dim_expr);
   }
 
+  // DimExpr only keeps shape info, which is always an integer type.
+  int64_t value = attributes.at("value")
+                      .dyn_cast<paddle::dialect::ScalarAttribute>()
+                      .data()
+                      .to<int64_t>();
+  std::vector<symbol::DimExpr> sym_data;
+  sym_data.emplace_back(value);
+
   symbol::ShapeOrDataDimExprs shape_data{sym_shape};
 
   op->set_attribute(
@@ -601,6 +609,54 @@ bool Unsqueeze_OpInferSymbolicShape(
 
 bool TileOpInferSymbolicShape(pir::Operation *op,
                               pir::ShapeConstraintIRAnalysis *shape_analysis) {
+  pir::Value operand_x = op->operand_source(0);
+  symbol::ShapeOrDataDimExprs x_shape_or_data =
+      shape_analysis->GetShapeOrDataForValue(operand_x);
+  pir::Value operand_repeat_times = op->operand_source(1);
+  symbol::ShapeOrDataDimExprs repeat_times_shape_or_data =
+      shape_analysis->GetShapeOrDataForValue(operand_repeat_times);
+
+  std::vector<symbol::DimExpr> x_dimexpr;
+  if (x_shape_or_data.data().has_value()) {
+    x_dimexpr = x_shape_or_data.data().value();
+  } else {
+    x_dimexpr = x_shape_or_data.shape();
+  }
+
+  std::vector<symbol::DimExpr> repeat_times_dimexpr;
+  if (repeat_times_shape_or_data.data().has_value()) {
+    repeat_times_dimexpr = repeat_times_shape_or_data.data().value();
+  } else {
+    repeat_times_dimexpr = repeat_times_shape_or_data.shape();
+  }
+  if (repeat_times_dimexpr.empty()) {
+    repeat_times_dimexpr = std::vector<symbol::DimExpr>(x_dimexpr.size(), 1);
+  }
+
+  auto out_rank = std::max(static_cast<size_t>(x_dimexpr.size()),
+                           repeat_times_dimexpr.size());
+  std::vector<symbol::DimExpr> out_shape(out_rank);
+  if (x_dimexpr.size() > repeat_times_dimexpr.size()) {
+    auto diff = x_dimexpr.size() - repeat_times_dimexpr.size();
+    repeat_times_dimexpr.insert(repeat_times_dimexpr.begin(), diff, 1);
+  } else {
+    auto diff = repeat_times_dimexpr.size() - x_dimexpr.size();
+    x_dimexpr.insert(x_dimexpr.begin(), diff, 1);
+  }
+
+  for (size_t i = 0; i < repeat_times_dimexpr.size(); ++i) {
+    out_shape[i] = x_dimexpr[i] * repeat_times_dimexpr[i];
+  }
+
+  symbol::ShapeOrDataDimExprs shape_data{out_shape};
+
+  op->set_attribute(
+      "symbolic_shape",
+      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
+                                       shape_data));
+
+  pir::OpResult res = op->result(0);
+  shape_analysis->SetShapeOrDataForValue(res, shape_data);
+
   return true;
 }
 
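Note: the TileOpInferSymbolicShape hunk above applies numpy-style rank alignment: the shorter of the input shape and repeat_times is left-padded with 1s, then each output dim is the element-wise product. A minimal plain-Python sketch of that rule (infer_tile_out_shape is a hypothetical helper, not part of this patch):

def infer_tile_out_shape(x_shape, repeat_times):
    # Empty repeat_times means "repeat every dim once".
    if not repeat_times:
        repeat_times = [1] * len(x_shape)
    # Left-pad the shorter list with 1s so both have the same rank.
    diff = len(x_shape) - len(repeat_times)
    repeat_times = [1] * max(diff, 0) + list(repeat_times)
    x_shape = [1] * max(-diff, 0) + list(x_shape)
    # Output dim = input dim * repeat factor, per axis.
    return [d * r for d, r in zip(x_shape, repeat_times)]

# Matches the unsqueeze+tile pattern used in the new test below:
assert infer_tile_out_shape([1, 300, 32, 1, 128], [1, 1, 1, 4, 1]) == [1, 300, 32, 4, 128]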
diff --git a/test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py b/test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py
index 7d84d55dab0672..738dd79eb840ad 100644
--- a/test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py
+++ b/test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py
@@ -232,5 +232,64 @@ def test_eval_symbolic(self):
         # np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)
 
 
+def unsqueeze_composite(x, axis):
+    """Composite rule of the unsqueeze op, implemented via reshape."""
+    x_shape = list(x.shape)
+    axis_list = list(axis)
+    for i in axis_list:
+        if i < 0:
+            i += len(x_shape) + 1
+        x_shape = (
+            x_shape[:i]
+            + [
+                1,
+            ]
+            + x_shape[i:]
+        )
+    out = paddle.reshape(x, x_shape)
+    return out
+
+
+class LlamaRepeatKV(paddle.nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.n_rep = 4
+
+    def forward(self, hidden_states):
+        batch, slen, num_key_value_heads, head_dim = hidden_states.shape
+        rst_unsqueeze = unsqueeze_composite(hidden_states, [-2])
+        rst_tile = rst_unsqueeze.tile([1, 1, 1, self.n_rep, 1])
+        out = rst_tile.reshape(
+            [batch, slen, num_key_value_heads * self.n_rep, head_dim]
+        )
+
+        return out
+
+
+class TestCinnDyShapeRepeatKV(TestCinnDyShapeBase):
+    def prepare_data(self):
+        self.hidden_states_shape = [1, 300, 32, 128]
+        self.hidden_states = paddle.randn(
+            self.hidden_states_shape, dtype="float32"
+        )
+        self.hidden_states.stop_gradient = False
+
+    def eval_symbolic(self, use_cinn):
+        paddle.seed(2022)
+        net = LlamaRepeatKV()
+        input_spec = [
+            InputSpec(shape=[None, None, 32, 128], dtype='float32'),
+        ]
+        net = apply_to_static(net, use_cinn, input_spec)
+        net.eval()
+        out = net(self.hidden_states)
+        return out
+
+    def test_eval_symbolic(self):
+        # cinn_out = self.eval_symbolic(use_cinn=True)
+        dy_out = self.eval_symbolic(use_cinn=False)
+        # np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)
+
+
 if __name__ == '__main__':
     unittest.main()
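For reference, a quick eager-mode check (hypothetical, not part of the test file) of the shapes TestCinnDyShapeRepeatKV exercises, assuming n_rep = 4 and the [1, 300, 32, 128] input prepared above:

import paddle

x = paddle.randn([1, 300, 32, 128], dtype="float32")
# unsqueeze(-2):        [1, 300, 32, 128] -> [1, 300, 32, 1, 128]
# tile([1, 1, 1, 4, 1]):                  -> [1, 300, 32, 4, 128]
# reshape:                                -> [1, 300, 128, 128]
y = x.unsqueeze(-2).tile([1, 1, 1, 4, 1]).reshape([1, 300, 32 * 4, 128])
assert y.shape == [1, 300, 128, 128]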