Commit bf06163

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into reconstruct_drr
2 parents 078fa99 + 277fe29 commit bf06163

112 files changed: 5303 additions & 1193 deletions

README.md

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ English | [简体中文](./README_cn.md) | [日本語](./README_ja.md)
 Welcome to the PaddlePaddle GitHub.
 
 PaddlePaddle, as the first independent R&D deep learning platform in China, has been officially open-sourced to professional communities since 2016. It is an industrial platform with advanced technologies and rich features that cover core deep learning frameworks, basic model libraries, end-to-end development kits, tools & components as well as service platforms.
-PaddlePaddle is originated from industrial practices with dedication and commitments to industrialization. It has been widely adopted by a wide range of sectors including manufacturing, agriculture, enterprise service, and so on while serving more than 8 million developers, 220,000 companies and generating 800,000 models. With such advantages, PaddlePaddle has helped an increasing number of partners commercialize AI.
+PaddlePaddle is originated from industrial practices with dedication and commitments to industrialization. It has been widely adopted by a wide range of sectors including manufacturing, agriculture, enterprise service, and so on while serving more than 10.7 million developers, 235,000 companies and generating 860,000 models. With such advantages, PaddlePaddle has helped an increasing number of partners commercialize AI.
 
 ## Installation
 
README_cn.md

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 
 欢迎来到 PaddlePaddle GitHub
 
-飞桨(PaddlePaddle)以百度多年的深度学习技术研究和业务应用为基础,是中国首个自主研发、功能完备、 开源开放的产业级深度学习平台,集深度学习核心训练和推理框架、基础模型库、端到端开发套件和丰富的工具组件于一体。目前,飞桨累计开发者800万,服务企业22万家,基于飞桨开源深度学习平台产生了80万个模型。飞桨助力开发者快速实现AI想法,快速上线AI业务。帮助越来越多的行业完成AI赋能,实现产业智能化升级。
+飞桨(PaddlePaddle)以百度多年的深度学习技术研究和业务应用为基础,是中国首个自主研发、功能完备、 开源开放的产业级深度学习平台,集深度学习核心训练和推理框架、基础模型库、端到端开发套件和丰富的工具组件于一体。目前,飞桨累计开发者1070万,服务企业23.5万家,基于飞桨开源深度学习平台产生了86万个模型。飞桨助力开发者快速实现AI想法,快速上线AI业务。帮助越来越多的行业完成AI赋能,实现产业智能化升级。
 
 ## 安装
 
README_ja.md

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@
 PaddlePaddle GitHub へようこそ。
 
 PaddlePaddle は中国初の独立系 R&D ディープラーニングプラットフォームとして、2016年からプロのコミュニティに正式にオープンソース化されました。コアとなる深層学習フレームワーク、基本モデルライブラリ、エンドツーエンドの開発キット、ツール&コンポーネント、さらにサービスプラットフォームを網羅する、高度な技術と豊富な機能を備えた産業プラットフォームです。
-PaddlePaddle は、工業化に対するコミットメントを持つ工業的実践から生まれたものです。製造業、農業、企業サービスなど幅広い分野で採用され、800万人以上の開発者、22万以上の企業、80万以上のモデルを生み出しています。それにより PaddlePaddle は、ますます多くのパートナーの AI 商用化を支援しています。
+PaddlePaddle は、工業化に対するコミットメントを持つ工業的実践から生まれたものです。製造業、農業、企業サービスなど幅広い分野で採用され、1070万人以上の開発者、23.5万以上の企業、86万以上のモデルを生み出しています。それにより PaddlePaddle は、ますます多くのパートナーの AI 商用化を支援しています。
 
 ## インストール
 
paddle/cinn/common/broadcast_tree.cc

Lines changed: 51 additions & 1 deletion

@@ -185,8 +185,8 @@ using Pattern2Placement = std::unordered_map<symbol::DimExpr, symbol::DimExpr>;
 Pattern2Placement ConstructCstrLhsEqRhsReplacement(
     const symbol::Broadcastable<symbol::DimExpr>& broadcastable_condition) {
   auto [lhs, rhs] = *broadcastable_condition;
-  if (lhs.isa<std::string>()) return Pattern2Placement{{lhs, rhs}};
   if (rhs.isa<std::string>()) return Pattern2Placement{{rhs, lhs}};
+  if (lhs.isa<std::string>()) return Pattern2Placement{{lhs, rhs}};
   return Pattern2Placement{{lhs, rhs}};
 }
 
@@ -295,4 +295,54 @@ BroadcastTree ConstructBroadcastTree(const BroadcastLeaf& leaves) {
   return ConstructBroadcastBranch(broadcastable_condition.value(), leaves);
 }
 
+namespace {
+
+std::string ToTxtStringImpl(const BroadcastBranch<BroadcastTree>& branch) {
+  std::stringstream ss;
+  const auto& [cstr, lhs_eq_rhs, lhs_eq_one, rhs_eq_one] = branch.tuple();
+  const auto& [lhs, rhs] = *cstr;
+  const auto& Put = [&](const std::string& key, const auto& value) {
+    ss << "\"" << key << "\": ";
+    ss << ToTxtString(value);
+    ss << ",\n ";
+  };
+  ss << "{";
+  ss << "\"$lhs\": " << lhs << ",\n ";
+  ss << "\"$rhs\": " << rhs << ",\n ";
+  Put("$lhs == $rhs", lhs_eq_rhs);
+  Put("$lhs == 1", lhs_eq_one);
+  Put("$rhs == 1", rhs_eq_one);
+  ss << "}";
+  return ss.str();
+}
+
+std::string ToTxtStringImpl(const BroadcastLeaf& leaf) {
+  std::stringstream ss;
+  ss << "[";
+  for (const auto& dim_exprs : *leaf) {
+    ss << "[";
+    int j = 0;
+    for (const auto& dim_expr : dim_exprs) {
+      if (j++) {
+        ss << ",";
+      }
+      ss << dim_expr;
+    }
+    ss << "]";
+  }
+  ss << "]";
+  return ss.str();
+}
+
+}  // namespace
+
+std::string ToTxtString(const BroadcastTree& tree) {
+  return std::visit([&](const auto& impl) { return ToTxtStringImpl(impl); },
+                    tree.variant());
+}
+
+std::ostream& operator<<(std::ostream& os, const BroadcastTree& tree) {
+  os << ToTxtString(tree);
+  return os;
+}
+
 }  // namespace cinn::common

paddle/cinn/common/broadcast_tree.h

Lines changed: 4 additions & 0 deletions

@@ -31,4 +31,8 @@ using BroadcastTree = adt::Tree<BroadcastBranch, BroadcastLeaf>;
 
 BroadcastTree ConstructBroadcastTree(const BroadcastLeaf& leaves);
 
+std::string ToTxtString(const BroadcastTree&);
+
+std::ostream& operator<<(std::ostream& os, const BroadcastTree& tree);
+
 }  // namespace cinn::common

paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt

Lines changed: 2 additions & 4 deletions

@@ -11,9 +11,7 @@ if(NOT CINN_ONLY)
       cinn_runtime_dialect
       pir_compiler)
 
-  cc_library(
-    cinn_transforms
-    SRCS ${cinn_transforms_srcs}
-    DEPS ${cinn_transforms_deps})
+  cinn_cc_library(cinn_transforms SRCS ${cinn_transforms_srcs} DEPS
+                  ${cinn_transforms_deps})
 
 endif()

paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h

Lines changed: 0 additions & 35 deletions
This file was deleted.

paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc

Lines changed: 2 additions & 3 deletions

@@ -138,9 +138,8 @@ bool ProcessOp(paddle::dialect::ExpandOp op, pir::PatternRewriter* rewriter) {
     pir::ShapeConstraintIRAnalysis& shape_analysis =
         pir::ShapeAnalysisManager::Instance().Get(
            op.x().defining_op()->GetParentProgram());
-    CHECK(shape_analysis.value_id_to_shapeordata_.find(GetValueId(&value)) !=
-          shape_analysis.value_id_to_shapeordata_.end());
-    return shape_analysis.value_id_to_shapeordata_.at(GetValueId(&value));
+
+    return shape_analysis.GetShapeOrDataForValue(value);
   };
   std::optional<pir::Value> opt_generated_shape =
       GetOutOfRewritedGenerateShapeOp(

paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc renamed to paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc

Lines changed: 49 additions & 37 deletions

@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h"
+#include "paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h"
 
 #include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
 #include "paddle/cinn/hlir/framework/pir/utils.h"
@@ -49,6 +49,14 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
   }
   pir::Value x = op->operand_source(0);
   pir::Value y = op->operand_source(1);
+  pir::ShapeConstraintIRAnalysis& shape_analysis =
+      pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
+  const auto& x_shape = shape_analysis.GetShapeOrDataForValue(x);
+  const auto& y_shape = shape_analysis.GetShapeOrDataForValue(y);
+  if (x_shape.shape() == y_shape.shape() && x_shape.data() == y_shape.data()) {
+    return false;
+  }
+
   pir::Value output_dim_tensor = GetOutputDimTensor(rewriter, x, y);
   {
     pir::Value broadcasted_x =
@@ -66,7 +74,7 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
 }  // namespace
 
 template <typename OPTYPE>
-class FullyInsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
+class InsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
  public:
   using pir::OpRewritePattern<OPTYPE>::OpRewritePattern;
 
@@ -76,42 +84,46 @@ class FullyInsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
   }
 };
 
-FullyInsertBroadcastPass::FullyInsertBroadcastPass()
-    : pir::PatternRewritePass("fully_insert_broadcast_pass", 1) {}
-
-pir::RewritePatternSet FullyInsertBroadcastPass::InitializePatterns(
-    pir::IrContext* context) {
-  pir::RewritePatternSet ps(context);
-  // elementwise ops
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::AddOp>>(context);
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::SubtractOp>>(context);
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MultiplyOp>>(context);
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::DivideOp>>(context);
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::ElementwisePowOp>>(
-      context);
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::RemainderOp>>(context);
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::FloorDivideOp>>(context);
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MaximumOp>>(context);
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MinimumOp>>(context);
-
-  // compare ops
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::LessThanOp>>(context);
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::LessEqualOp>>(context);
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::EqualOp>>(context);
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::NotEqualOp>>(context);
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::GreaterThanOp>>(context);
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::GreaterEqualOp>>(context);
-
-  // bitwise ops
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseOrOp>>(context);
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseXorOp>>(context);
-  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseNotOp>>(context);
-
-  return ps;
-}
+class InsertBroadcastPass : public pir::PatternRewritePass {
+ public:
+  InsertBroadcastPass() : pir::PatternRewritePass("insert_broadcast_pass", 1) {}
+
+  pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override {
+    pir::RewritePatternSet ps(context);
+    // elementwise ops
+    ps.Add<InsertBroadcastPattern<paddle::dialect::AddOp>>(context);
+    ps.Add<InsertBroadcastPattern<paddle::dialect::SubtractOp>>(context);
+    ps.Add<InsertBroadcastPattern<paddle::dialect::MultiplyOp>>(context);
+    ps.Add<InsertBroadcastPattern<paddle::dialect::DivideOp>>(context);
+    ps.Add<InsertBroadcastPattern<paddle::dialect::ElementwisePowOp>>(context);
+    ps.Add<InsertBroadcastPattern<paddle::dialect::RemainderOp>>(context);
+    ps.Add<InsertBroadcastPattern<paddle::dialect::FloorDivideOp>>(context);
+    ps.Add<InsertBroadcastPattern<paddle::dialect::MaximumOp>>(context);
+    ps.Add<InsertBroadcastPattern<paddle::dialect::MinimumOp>>(context);
+
+    // compare ops
+    ps.Add<InsertBroadcastPattern<paddle::dialect::LessThanOp>>(context);
+    ps.Add<InsertBroadcastPattern<paddle::dialect::LessEqualOp>>(context);
+    ps.Add<InsertBroadcastPattern<paddle::dialect::EqualOp>>(context);
+    ps.Add<InsertBroadcastPattern<paddle::dialect::NotEqualOp>>(context);
+    ps.Add<InsertBroadcastPattern<paddle::dialect::GreaterThanOp>>(context);
+    ps.Add<InsertBroadcastPattern<paddle::dialect::GreaterEqualOp>>(context);
+
+    // bitwise ops
+    ps.Add<InsertBroadcastPattern<paddle::dialect::BitwiseOrOp>>(context);
+    ps.Add<InsertBroadcastPattern<paddle::dialect::BitwiseXorOp>>(context);
+    ps.Add<InsertBroadcastPattern<paddle::dialect::BitwiseNotOp>>(context);
+
+    return ps;
+  }
+
+  bool CanApplyOn(pir::Operation* op) const override {
+    return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
+  }
+};
 
-bool FullyInsertBroadcastPass::CanApplyOn(pir::Operation* op) const {
-  return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
+std::unique_ptr<pir::Pass> CreateInsertBroadcastPass() {
+  return std::make_unique<InsertBroadcastPass>();
 }
 
 }  // namespace ir

paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.h renamed to paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h

Lines changed: 8 additions & 7 deletions

@@ -14,13 +14,14 @@
 
 #pragma once
 
-#include <memory>
-#include "paddle/pir/core/dll_decl.h"
+#include "paddle/pir/pass/pass.h"
 
-namespace pir {
+namespace cinn {
+namespace dialect {
+namespace ir {
 
-class Pass;
+IR_API std::unique_ptr<pir::Pass> CreateInsertBroadcastPass();
 
-IR_API std::unique_ptr<Pass> CreateFcWithSpecialOpFusePass();
-
-}  // namespace pir
+}  // namespace ir
+}  // namespace dialect
+}  // namespace cinn
