Skip to content
This repository was archived by the owner on Sep 9, 2025. It is now read-only.

Commit cc6895e

Browse files
authored
fix one_to_many op Canonicalization (PaddlePaddle#15)
* fix one_to_many op Canonicalization * rename func
1 parent a03acc3 commit cc6895e

File tree

3 files changed

+31
-19
lines changed

3 files changed

+31
-19
lines changed

paddle/fluid/framework/ipu/popart_canonicalization_utils.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,29 @@ SymbolHandler GetHandler(const std::string &kind) {
3838
return {};
3939
}
4040

41+
void MoveNodeInputs(ir::Node *node, ir::Node *new_node) {
42+
new_node->inputs = node->inputs;
43+
for (auto *node_in : node->inputs) {
44+
for (size_t i = 0; i < node_in->outputs.size(); ++i) {
45+
if (node_in->outputs[i] == node) {
46+
node_in->outputs[i] = new_node;
47+
break;
48+
}
49+
}
50+
}
51+
}
52+
53+
void MoveNodeOutputs(ir::Node *node, ir::Node *new_node) {
54+
new_node->outputs = node->outputs;
55+
for (auto *node_out : node->outputs) {
56+
for (size_t i = 0; i < node_out->inputs.size(); ++i) {
57+
if (node_out->inputs[i] == node) {
58+
node_out->inputs[i] = new_node;
59+
break;
60+
}
61+
}
62+
}
63+
}
64+
4165
} // namespace framework
4266
} // namespace paddle

paddle/fluid/framework/ipu/popart_canonicalization_utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,8 @@ bool RegisterHandler(const std::string &, const SymbolHandler &);
3636

3737
SymbolHandler GetHandler(const std::string &);
3838

39+
void MoveNodeInputs(ir::Node *node, ir::Node *new_node);
40+
void MoveNodeOutputs(ir::Node *node, ir::Node *new_node);
41+
3942
} // namespace framework
4043
} // namespace paddle

paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,27 +43,12 @@ void PopartCanonicalizationPass::ApplyImpl(ir::Graph* graph) const {
4343
SymbolHandler handler = GetHandler(op_type);
4444
if (handler) {
4545
new_node = handler(graph, node);
46-
new_node->inputs = node->inputs;
47-
new_node->outputs = node->outputs;
48-
// restore node releations
49-
for (auto* node_in : node->inputs) {
50-
for (size_t i = 0; i < node_in->outputs.size(); ++i) {
51-
if (node_in->outputs[i] == node) {
52-
node_in->outputs[i] = new_node;
53-
break;
54-
}
55-
}
46+
if (new_node->inputs.empty()) {
47+
MoveNodeInputs(node, new_node);
5648
}
57-
for (auto* node_out : node->outputs) {
58-
for (size_t i = 0; i < node_out->inputs.size(); ++i) {
59-
if (node_out->inputs[i] == node) {
60-
node_out->inputs[i] = new_node;
61-
break;
62-
}
63-
}
49+
if (new_node->outputs.empty()) {
50+
MoveNodeOutputs(node, new_node);
6451
}
65-
}
66-
if (new_node) {
6752
graph->RemoveNode(node);
6853
}
6954
}

0 commit comments

Comments
 (0)