Skip to content

Commit a1aae04

Browse files
authored
[Inference] Replace unordered_map with map to support subgraph stability (#35147)
1 parent e4a8815 commit a1aae04

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

paddle/fluid/framework/ir/subgraph_detector.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ void SubgraphDetector::MarkNodesInsideSubGraph() {
117117
// Use the Union Find(UF) algorithm to find fully connected sub-graphs, if node
118118
// a's output is node b, that is a and b is in the same sub-graph. The UF
119119
// algorithm will group them to the same cluster.
120-
using node_map_t = std::unordered_map<int, Node *>;
120+
using node_map_t = std::map<int, Node *>;
121121
// Find the ancestor id of a node.
122122
int UnionFindGetAncestor(const node_map_t &node_map, size_t id) {
123123
int tmp = id;
@@ -155,7 +155,7 @@ struct BriefNode {
155155
// 3. change all the dst's inputs and outputs
156156
// corresponding inlinks and outlinks to src node.
157157
// 4. delete all dst's inlinks and outlinks.
158-
void UnionContractedNodes(const std::unordered_map<int, BriefNode *> &node_map,
158+
void UnionContractedNodes(const std::map<int, BriefNode *> &node_map,
159159
int src_id, int dst_id) {
160160
// merge the two adjacent nodes into one node.
161161
BriefNode *src_node = node_map.at(src_id);
@@ -262,7 +262,7 @@ std::vector<std::vector<Node *>> SubgraphDetector::ExtractSubGraphs() {
262262
std::vector<Node *> marked_nodes;
263263
// We use brief_node_map to represent the original graph in order to avoid
264264
// changing the original graph.
265-
std::unordered_map<int, BriefNode *> brief_node_map;
265+
std::map<int, BriefNode *> brief_node_map;
266266

267267
std::unordered_set<int32_t> valid_node_ids;
268268
for (auto *node : graph_->Nodes()) {

python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def build_program(self, dtype):
167167

168168
self.append_gradients(tmp_3)
169169

170-
self.num_fused_ops = 4
170+
self.num_fused_ops = 3
171171
self.fetch_list = [tmp_3, self.grad(tmp_0)]
172172

173173

0 commit comments

Comments
 (0)