diff --git a/onnxruntime/core/codegen/common/common.cc b/onnxruntime/core/codegen/common/common.cc index 2ce09514acfb9..36e83dbd72105 100644 --- a/onnxruntime/core/codegen/common/common.cc +++ b/onnxruntime/core/codegen/common/common.cc @@ -117,10 +117,10 @@ const onnxruntime::Node* GetInputNode(const Node& node, const NodeArg* def) { // create capacity from subgraph std::unique_ptr ToCapacity(const onnxruntime::GraphViewer& graph, + int fused_count, std::unique_ptr& subgraph) { auto meta_def = onnxruntime::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); - static int fuse_count = 0; - meta_def->name = "Fuse" + std::to_string(fuse_count++); + meta_def->name = "Fuse" + std::to_string(fused_count); meta_def->domain = "Fuse"; std::set node_indices(subgraph->nodes.begin(), subgraph->nodes.end()); diff --git a/onnxruntime/core/codegen/common/common.h b/onnxruntime/core/codegen/common/common.h index 11ad05325a381..fd27d54f42ac6 100644 --- a/onnxruntime/core/codegen/common/common.h +++ b/onnxruntime/core/codegen/common/common.h @@ -113,6 +113,7 @@ bool IsAliasNode(const onnxruntime::Node& node); // Helper function that creates ComputeCapability for subgraphs std::unique_ptr ToCapacity(const onnxruntime::GraphViewer& graph, + int fused_count, std::unique_ptr& subgraph); bool IsFusedNode(const Node& node); diff --git a/onnxruntime/core/providers/nuphar/nuphar_execution_provider.cc b/onnxruntime/core/providers/nuphar/nuphar_execution_provider.cc index 0c8566aa49c8f..6e76bc45d0ad9 100644 --- a/onnxruntime/core/providers/nuphar/nuphar_execution_provider.cc +++ b/onnxruntime/core/providers/nuphar/nuphar_execution_provider.cc @@ -25,6 +25,7 @@ namespace onnxruntime { thread_local int64_t NupharSubgraphUnit::counter = 0; thread_local std::unique_ptr> NupharExecutionProvider::tls_realized_dims_; +int NupharExecutionProvider::global_fused_count_ = 0; static std::string GetCurrentHostTargetString() { #if USE_TVM_WITH_LLVM @@ -311,7 +312,12 @@ 
NupharExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie }; GraphPartitioner graph_partitioner(is_supported_func); - ORT_ENFORCE(graph_partitioner.Partition(graph_viewer, results).IsOK()); + ORT_ENFORCE(graph_partitioner.Partition(graph_viewer, global_fused_count_, results).IsOK()); + + // reset global_fused_count_ once the main graph has been partitioned: subgraphs may be partitioned by separate sessions, + // and since ORT handles the main graph last, all graph cuts should be finished by this point + if (!graph_viewer.IsSubgraph()) + global_fused_count_ = 0; // for any node being fused in results, save initializer tensors // because IExecutionProvider::Compile would be called without OpKernelInfo diff --git a/onnxruntime/core/providers/nuphar/nuphar_execution_provider.h b/onnxruntime/core/providers/nuphar/nuphar_execution_provider.h index 07ca11e4f09da..edfb5aa14ffc3 100644 --- a/onnxruntime/core/providers/nuphar/nuphar_execution_provider.h +++ b/onnxruntime/core/providers/nuphar/nuphar_execution_provider.h @@ -153,6 +153,10 @@ class NupharExecutionProvider : public IExecutionProvider { mutable std::unordered_map> constant_initializers_used_in_compiled_nodes_; mutable std::unordered_map domain_versions_; + + // used to create unique fused node names; it is static because + // subsessions may create multiple instances of the EP + static int global_fused_count_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/partition/graph_partitioner.cc b/onnxruntime/core/providers/nuphar/partition/graph_partitioner.cc index 010b4a3bf1d6f..6ca92d6933c51 100644 --- a/onnxruntime/core/providers/nuphar/partition/graph_partitioner.cc +++ b/onnxruntime/core/providers/nuphar/partition/graph_partitioner.cc @@ -140,6 +140,7 @@ bool GraphPartitioner::ForcePartition( // Partition the graph (fusing ops) based on the dependency and whether ops are supported: Status GraphPartitioner::Partition(const onnxruntime::GraphViewer& graph, + int& fused_count, std::vector>& result) { 
// call partition ORT_RETURN_IF_ERROR(Evaluate(graph, /*distinguish_subgraph*/ true)); @@ -170,9 +171,9 @@ Status GraphPartitioner::Partition(const onnxruntime::GraphViewer& graph, if (codegen::CodeGenSettings::Instance().HasOption(kNupharDumpPartition)) { std::ostringstream stream; if (graph.IsSubgraph()) { - stream << "[NUPHAR_DUMP_PARTITION] ## Subgraph ## " << std::endl; + stream << "[NUPHAR_DUMP_PARTITION] ## Subgraph ## Fused graph ID " << fused_count << std::endl; } else { - stream << "[NUPHAR_DUMP_PARTITION]" << std::endl; + stream << "[NUPHAR_DUMP_PARTITION] ## Fused graph ID " << fused_count << std::endl; } stream << "Partition of size " << iter.second.nodes.size() << " ["; for (const auto& node_index : partition->nodes) { @@ -186,6 +187,7 @@ Status GraphPartitioner::Partition(const onnxruntime::GraphViewer& graph, result.emplace_back( ToCapacity( graph, + fused_count++, partition)); } diff --git a/onnxruntime/core/providers/nuphar/partition/graph_partitioner.h b/onnxruntime/core/providers/nuphar/partition/graph_partitioner.h index 5ab4035899264..197a89ee8698f 100644 --- a/onnxruntime/core/providers/nuphar/partition/graph_partitioner.h +++ b/onnxruntime/core/providers/nuphar/partition/graph_partitioner.h @@ -22,6 +22,7 @@ class GraphPartitioner : public Partitioner { : Partitioner(), is_op_type_supported_func_(is_op_type_supported_func) {} Status Partition(const onnxruntime::GraphViewer& graph, + int& fused_count, std::vector>& result); private: diff --git a/onnxruntime/core/providers/nuphar/partition/subgraph_partitioner.cc b/onnxruntime/core/providers/nuphar/partition/subgraph_partitioner.cc index 29c5339f6ebef..ce32f3080abce 100644 --- a/onnxruntime/core/providers/nuphar/partition/subgraph_partitioner.cc +++ b/onnxruntime/core/providers/nuphar/partition/subgraph_partitioner.cc @@ -387,7 +387,7 @@ Status SubgraphPartitioner::Partition( if (codegen::CodeGenSettings::Instance().HasOption(kNupharDumpFusedNodes)) { std::ostringstream stream; - stream << 
"[NUPHAR_DUMP_FUSED_NODES]" << std::endl; + stream << "[NUPHAR_DUMP_FUSED_NODES] ID " << subgraph.UniqueId() << std::endl; stream << "NupharSubgraphUnit of size " << results.back().nodes.size() << " ["; for (const auto& n : results.back().nodes) { stream << "(" << n->Name() << ", " << n->OpType() << ") ";