Skip to content

Commit f2af76f

Browse files
committed
Remove local static fuse_count
Added a static field global_fused_count_ to the NupharExecutionProvider class
1 parent 25b6e76 commit f2af76f

File tree

6 files changed

+19
-8
lines changed

6 files changed

+19
-8
lines changed

onnxruntime/core/codegen/common/common.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,10 @@ const onnxruntime::Node* GetInputNode(const Node& node, const NodeArg* def) {
117117

118118
// create capacity from subgraph
119119
std::unique_ptr<ComputeCapability> ToCapacity(const onnxruntime::GraphViewer& graph,
120-
int fuse_count,
120+
int fused_count,
121121
std::unique_ptr<IndexedSubGraph>& subgraph) {
122122
auto meta_def = onnxruntime::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>();
123-
meta_def->name = "Fuse" + std::to_string(fuse_count);
123+
meta_def->name = "Fuse" + std::to_string(fused_count);
124124
meta_def->domain = "Fuse";
125125

126126
std::set<NodeIndex> node_indices(subgraph->nodes.begin(), subgraph->nodes.end());

onnxruntime/core/codegen/common/common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ bool IsAliasNode(const onnxruntime::Node& node);
113113

114114
// Helper function that creates ComputeCapability for subgraphs
115115
std::unique_ptr<ComputeCapability> ToCapacity(const onnxruntime::GraphViewer& graph,
116-
int fuse_count,
116+
int fused_count,
117117
std::unique_ptr<IndexedSubGraph>& subgraph);
118118

119119
bool IsFusedNode(const Node& node);

onnxruntime/core/providers/nuphar/nuphar_execution_provider.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ namespace onnxruntime {
2525
thread_local int64_t NupharSubgraphUnit::counter = 0;
2626

2727
thread_local std::unique_ptr<std::unordered_map<std::string, int64_t>> NupharExecutionProvider::tls_realized_dims_;
28+
int NupharExecutionProvider::global_fused_count_ = 0;
2829

2930
static std::string GetCurrentHostTargetString() {
3031
#if USE_TVM_WITH_LLVM
@@ -311,7 +312,12 @@ NupharExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
311312
};
312313
GraphPartitioner graph_partitioner(is_supported_func);
313314

314-
ORT_ENFORCE(graph_partitioner.Partition(graph_viewer, results).IsOK());
315+
ORT_ENFORCE(graph_partitioner.Partition(graph_viewer, global_fused_count_, results).IsOK());
316+
317+
// Reset global_fused_count_ once the main graph has been partitioned. There might be multiple sessions
318+
// for subgraphs, but ORT handles the main graph last, so all graph cuts should be finished by this point.
319+
if (!graph_viewer.IsSubgraph())
320+
global_fused_count_ = 0;
315321

316322
// for any node being fused in results, save initializer tensors
317323
// because IExecutionProvider::Compile would be called without OpKernelInfo

onnxruntime/core/providers/nuphar/nuphar_execution_provider.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ class NupharExecutionProvider : public IExecutionProvider {
153153

154154
mutable std::unordered_map<std::string, std::unique_ptr<Tensor>> constant_initializers_used_in_compiled_nodes_;
155155
mutable std::unordered_map<std::string, int> domain_versions_;
156+
157+
// Used to create unique fused node names. It is static because
158+
// subsessions may create multiple instances of the EP.
159+
static int global_fused_count_;
156160
};
157161

158162
} // namespace onnxruntime

onnxruntime/core/providers/nuphar/partition/graph_partitioner.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ bool GraphPartitioner::ForcePartition(
140140

141141
// Partition the graph (fusing ops) based on the dependency and whether ops are supported:
142142
Status GraphPartitioner::Partition(const onnxruntime::GraphViewer& graph,
143+
int& fused_count,
143144
std::vector<std::unique_ptr<ComputeCapability>>& result) {
144145
// call partition
145146
ORT_RETURN_IF_ERROR(Evaluate(graph, /*distinguish_subgraph*/ true));
@@ -167,13 +168,12 @@ Status GraphPartitioner::Partition(const onnxruntime::GraphViewer& graph,
167168
partition->nodes.push_back(n);
168169
}
169170

170-
static int fuse_count = 0;
171171
if (codegen::CodeGenSettings::Instance().HasOption(kNupharDumpPartition)) {
172172
std::ostringstream stream;
173173
if (graph.IsSubgraph()) {
174-
stream << "[NUPHAR_DUMP_PARTITION] ## Subgraph ## Fused graph ID " << fuse_count << std::endl;
174+
stream << "[NUPHAR_DUMP_PARTITION] ## Subgraph ## Fused graph ID " << fused_count << std::endl;
175175
} else {
176-
stream << "[NUPHAR_DUMP_PARTITION] ## Fused graph ID " << fuse_count << std::endl;
176+
stream << "[NUPHAR_DUMP_PARTITION] ## Fused graph ID " << fused_count << std::endl;
177177
}
178178
stream << "Partition of size " << iter.second.nodes.size() << " [";
179179
for (const auto& node_index : partition->nodes) {
@@ -187,7 +187,7 @@ Status GraphPartitioner::Partition(const onnxruntime::GraphViewer& graph,
187187
result.emplace_back(
188188
ToCapacity(
189189
graph,
190-
fuse_count++,
190+
fused_count++,
191191
partition));
192192
}
193193

onnxruntime/core/providers/nuphar/partition/graph_partitioner.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class GraphPartitioner : public Partitioner {
2222
: Partitioner(), is_op_type_supported_func_(is_op_type_supported_func) {}
2323

2424
Status Partition(const onnxruntime::GraphViewer& graph,
25+
int& fused_count,
2526
std::vector<std::unique_ptr<ComputeCapability>>& result);
2627

2728
private:

0 commit comments

Comments
 (0)