@@ -58,13 +58,48 @@ class FullOpPattern : public pir::OpRewritePattern<paddle::dialect::FullOp> {
   }
 };
 
+class CombineOpPattern : public pir::OpRewritePattern<pir::CombineOp> {
+ public:
+  using pir::OpRewritePattern<pir::CombineOp>::OpRewritePattern;
+
+  bool Match(pir::CombineOp op) const override {
+    auto out_type = op.result(0).type().dyn_cast<pir::VectorType>();
+    for (auto type : out_type.data()) {
+      if (HasZeroDim(type)) return true;
+    }
+    return false;
+  }
+
+  void Rewrite(pir::CombineOp op,
+               pir::PatternRewriter &rewriter) const override {
+    pir::Builder builder(rewriter.ir_context());
+
+    const std::vector<pir::Type> inputs_type = [&]() {
+      std::vector<pir::Type> types;
+      for (auto value : op->operands_source()) {
+        types.push_back(value.type());
+      }
+      return types;
+    }();
+    op.result(0).set_type(builder.vec_type(inputs_type));
+  }
+
+ private:
+  bool HasZeroDim(pir::Type type) const {
+    if (!type) return false;
+    const auto dense_tensor_type = type.dyn_cast<pir::DenseTensorType>();
+    return dense_tensor_type && (dense_tensor_type.dims().size() == 0U);
+  }
+};
+
 class Convert0DTo1DPass : public pir::PatternRewritePass {
  public:
   Convert0DTo1DPass() : pir::PatternRewritePass("convert_0D_to_1D", 1) {}
 
   pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
     pir::RewritePatternSet ps(context);
     ps.Add<FullOpPattern>(context);
+    ps.Add<CombineOpPattern>(context);
 
     return ps;
   }
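For context, a minimal sketch of how a pattern-rewrite pass like this one is typically driven. Only the pir::PassManager workflow is standard PIR usage; the include paths and the creator function CreateConvert0DTo1DPass are illustrative assumptions, not taken from this diff:

#include "paddle/pir/core/ir_context.h"    // header paths vary across versions
#include "paddle/pir/pass/pass_manager.h"  // header paths vary across versions

// Hypothetical driver: register the convert_0D_to_1D pass and run it over a
// program so 0-D tensor types are rewritten to 1-D before CINN lowering.
void ApplyConvert0DTo1D(pir::Program *program) {
  pir::IrContext *ctx = pir::IrContext::Instance();
  pir::PassManager pm(ctx);
  pm.AddPass(CreateConvert0DTo1DPass());  // assumed creator returning the pass
  pm.Run(program);
}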
test/ir/pir/cinn/sub_graphs/test_sub_graph_23.py (8 changes: 3 additions & 5 deletions)
@@ -31,8 +31,7 @@ def forward(
         var_0,  # (shape: [11, 24, 56, 56], dtype: paddle.float32, stop_gradient: False)
         var_1,  # (shape: [11, 24, 56, 56], dtype: paddle.float32, stop_gradient: False)
     ):
-        var_2 = paddle.tensor.attribute.shape(var_0)
-        var_3 = var_2[0]
+        var_3 = var_0.shape[0]
         var_4 = paddle.tensor.random.rand(shape=[var_3, 1, 1, 1])
         var_5 = 0.975 + var_4
         var_6 = paddle.tensor.ops.floor(var_5)
@@ -65,16 +64,15 @@ def train(self, net, to_static, with_prim=False, with_cinn=False):
         outs = net(*self.inputs)
         return outs
 
-    # NOTE prim + cinn lead to error
     def test_ast_prim_cinn(self):
         st_out = self.train(self.net, to_static=True)
         cinn_out = self.train(
-            self.net, to_static=True, with_prim=True, with_cinn=False
+            self.net, to_static=True, with_prim=True, with_cinn=True
         )
         for st, cinn in zip(
             paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out)
         ):
-            np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-8)
+            np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-6)
 
 
 if __name__ == '__main__':
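Why CombineOpPattern is needed, sketched under an assumed pseudo-type notation (the shapes below are illustrative, not printer output): once an earlier pattern such as FullOpPattern upgrades a 0-D value to 1-D, a downstream pir::CombineOp can still carry the old element types in its VectorType result, so Rewrite() rebuilds that type from op->operands_source().

// Illustrative only (assumed textual notation for types):
//
//   before: combine(%a : tensor<1xf32>, %b : tensor<4xf32>)
//             -> vec<[tensor<f32>, tensor<4xf32>]>    // stale 0-D entry
//   after : combine(%a : tensor<1xf32>, %b : tensor<4xf32>)
//             -> vec<[tensor<1xf32>, tensor<4xf32>]>  // rebuilt from operand types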