@@ -508,6 +508,14 @@ bool FullOpInferSymbolicShape(pir::Operation *op,
    sym_shape.push_back(dim_expr);
  }

  // DimExpr only keeps shape info, which is always an integer type.
  int64_t value = attributes.at("value")
                      .dyn_cast<paddle::dialect::ScalarAttribute>()
                      .data()
                      .to<int64_t>();
  std::vector<symbol::DimExpr> sym_data;
  sym_data.emplace_back(value);
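  // Illustration (hypothetical values, not part of the patch): for a
  // single-element full(shape=[1], value=3), the recorded exprs would be
  // sym_shape = {1} and sym_data = {3}, so a later op that consumes this
  // tensor as a shape can fold in the constant value.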

  symbol::ShapeOrDataDimExprs shape_data{sym_shape};

  op->set_attribute(
@@ -601,6 +609,54 @@ bool Unsqueeze_OpInferSymbolicShape(

bool TileOpInferSymbolicShape(pir::Operation *op,
                              pir::ShapeConstraintIRAnalysis *shape_analysis) {
  pir::Value operand_x = op->operand_source(0);
  symbol::ShapeOrDataDimExprs x_shape_or_data =
      shape_analysis->GetShapeOrDataForValue(operand_x);
  pir::Value operand_repeat_times = op->operand_source(1);
  symbol::ShapeOrDataDimExprs repeat_times_shape_or_data =
      shape_analysis->GetShapeOrDataForValue(operand_repeat_times);

  // Prefer the value-level dim exprs (data) when available; otherwise fall
  // back to the shape-level dim exprs.
  std::vector<symbol::DimExpr> x_dimexpr;
  if (x_shape_or_data.data().has_value()) {
    x_dimexpr = x_shape_or_data.data().value();
  } else {
    x_dimexpr = x_shape_or_data.shape();
  }

  std::vector<symbol::DimExpr> repeat_times_dimexpr;
  if (repeat_times_shape_or_data.data().has_value()) {
    repeat_times_dimexpr = repeat_times_shape_or_data.data().value();
  } else {
    repeat_times_dimexpr = repeat_times_shape_or_data.shape();
  }
  // An empty repeat_times means no repetition along any axis.
  if (repeat_times_dimexpr.empty()) {
    repeat_times_dimexpr = std::vector<symbol::DimExpr>(x_dimexpr.size(), 1);
  }

  // Align the two rank lists by left-padding the shorter one with 1s, then
  // multiply element-wise to get the output shape.
  auto out_rank = std::max(static_cast<size_t>(x_dimexpr.size()),
                           repeat_times_dimexpr.size());
  std::vector<symbol::DimExpr> out_shape(out_rank);
  if (x_dimexpr.size() > repeat_times_dimexpr.size()) {
    auto diff = x_dimexpr.size() - repeat_times_dimexpr.size();
    repeat_times_dimexpr.insert(repeat_times_dimexpr.begin(), diff, 1);
  } else {
    auto diff = repeat_times_dimexpr.size() - x_dimexpr.size();
    x_dimexpr.insert(x_dimexpr.begin(), diff, 1);
  }

  for (size_t i = 0; i < repeat_times_dimexpr.size(); ++i) {
    out_shape[i] = x_dimexpr[i] * repeat_times_dimexpr[i];
  }

  symbol::ShapeOrDataDimExprs shape_data{out_shape};

  op->set_attribute(
      "symbolic_shape",
      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));

  pir::OpResult res = op->result(0);
  shape_analysis->SetShapeOrDataForValue(res, shape_data);

  return true;
}
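
For reference, a minimal Python sketch of the same rank-alignment rule (illustration only, not part of the patch; tile_out_shape is a hypothetical helper using plain ints in place of symbol::DimExpr):

def tile_out_shape(x_shape, repeat_times):
    # An empty repeat_times means no repetition along any axis.
    if not repeat_times:
        repeat_times = [1] * len(x_shape)
    # Left-pad the shorter list with 1s so both ranks match, then
    # multiply element-wise.
    diff = len(x_shape) - len(repeat_times)
    if diff > 0:
        repeat_times = [1] * diff + repeat_times
    elif diff < 0:
        x_shape = [1] * (-diff) + x_shape
    return [d * r for d, r in zip(x_shape, repeat_times)]

assert tile_out_shape([1, 300, 32, 1, 128], [1, 1, 1, 4, 1]) == [1, 300, 32, 4, 128]
assert tile_out_shape([3, 4], [2, 1, 2]) == [2, 3, 8]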

test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py: 60 additions & 0 deletions
@@ -232,5 +232,65 @@ def test_eval_symbolic(self):
        # np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)


def unsqueeze_composite(x, axis):
    """Composite rule of the unsqueeze op, implemented via reshape."""
    x_shape = list(x.shape)
    axis_list = list(axis)
    for i in axis_list:
        if i < 0:
            i += len(x_shape) + 1
        x_shape = (
            x_shape[:i]
            + [
                1,
            ]
            + x_shape[i:]
        )
    out = paddle.reshape(x, x_shape)
    return out
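
# Illustration only, not part of the patch: unsqueezing axis -2 inserts a
# singleton dimension before the last axis.
#
#   x = paddle.randn([1, 300, 32, 128])
#   y = unsqueeze_composite(x, [-2])
#   assert list(y.shape) == [1, 300, 32, 1, 128]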


class LlamaRepeatKV(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.n_rep = 4

    def forward(self, hidden_states):
        batch, slen, num_key_value_heads, head_dim = hidden_states.shape
        rst_unsqueeze = unsqueeze_composite(hidden_states, [-2])
        rst_tile = rst_unsqueeze.tile([1, 1, 1, self.n_rep, 1])
        out = rst_tile.reshape(
            [batch, slen, num_key_value_heads * self.n_rep, head_dim]
        )

        return out
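
# Shape trace for LlamaRepeatKV (illustrative, using the test's concrete
# sizes and the default n_rep = 4):
#   input               [1, 300, 32, 128]
#   unsqueeze_composite [1, 300, 32, 1, 128]
#   tile                [1, 300, 32, 4, 128]
#   reshape             [1, 300, 128, 128]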


class TestCinnDyShapeRepeatKV(TestCinnDyShapeBase):
    def prepare_data(self):
        self.hidden_states_shape = [1, 300, 32, 128]
        self.hidden_states = paddle.randn(
            self.hidden_states_shape, dtype="float32"
        )
        self.hidden_states.stop_gradient = False

    def eval_symbolic(self, use_cinn):
        paddle.seed(2022)
        net = LlamaRepeatKV()
        input_spec = [
            InputSpec(shape=[None, None, 32, 128], dtype='float32'),
        ]
        net = apply_to_static(net, use_cinn, input_spec)
        net.eval()
        out = net(self.hidden_states)
        return out

    def test_eval_symbolic(self):
        # cinn_out = self.eval_symbolic(use_cinn=True)
        dy_out = self.eval_symbolic(use_cinn=False)
        # np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)


if __name__ == '__main__':
    unittest.main()