Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions onnxruntime/core/codegen/common/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ const onnxruntime::Node* GetInputNode(const Node& node, const NodeArg* def) {

// create capacity from subgraph
std::unique_ptr<ComputeCapability> ToCapacity(const onnxruntime::GraphViewer& graph,
int fused_count,
std::unique_ptr<IndexedSubGraph>& subgraph) {
auto meta_def = onnxruntime::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>();
static int fuse_count = 0;
meta_def->name = "Fuse" + std::to_string(fuse_count++);
meta_def->name = "Fuse" + std::to_string(fused_count);
meta_def->domain = "Fuse";

std::set<NodeIndex> node_indices(subgraph->nodes.begin(), subgraph->nodes.end());
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/codegen/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ bool IsAliasNode(const onnxruntime::Node& node);

// Helper function that creates ComputeCapability for subgraphs
std::unique_ptr<ComputeCapability> ToCapacity(const onnxruntime::GraphViewer& graph,
int fused_count,
std::unique_ptr<IndexedSubGraph>& subgraph);

bool IsFusedNode(const Node& node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ namespace onnxruntime {
thread_local int64_t NupharSubgraphUnit::counter = 0;

thread_local std::unique_ptr<std::unordered_map<std::string, int64_t>> NupharExecutionProvider::tls_realized_dims_;
int NupharExecutionProvider::global_fused_count_ = 0;

static std::string GetCurrentHostTargetString() {
#if USE_TVM_WITH_LLVM
Expand Down Expand Up @@ -311,7 +312,12 @@ NupharExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
};
GraphPartitioner graph_partitioner(is_supported_func);

ORT_ENFORCE(graph_partitioner.Partition(graph_viewer, results).IsOK());
ORT_ENFORCE(graph_partitioner.Partition(graph_viewer, global_fused_count_, results).IsOK());

// reset global_fused_count_ for main graph, since there might be multiple sessions for subgraphs,
// this is the time all graph cut should be finished as ORT handles main graph last
if (!graph_viewer.IsSubgraph())
global_fused_count_ = 0;

// for any node being fused in results, save initializer tensors
// because IExecutionProvider::Compile would be called without OpKernelInfo
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/nuphar/nuphar_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ class NupharExecutionProvider : public IExecutionProvider {

mutable std::unordered_map<std::string, std::unique_ptr<Tensor>> constant_initializers_used_in_compiled_nodes_;
mutable std::unordered_map<std::string, int> domain_versions_;

// used to create unique fused node name, make it static because
// subsession may create multiple instances of EPs
static int global_fused_count_;
};

} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ bool GraphPartitioner::ForcePartition(

// Partition the graph (fusing ops) based on the dependency and whether ops are supported:
Status GraphPartitioner::Partition(const onnxruntime::GraphViewer& graph,
int& fused_count,
std::vector<std::unique_ptr<ComputeCapability>>& result) {
// call partition
ORT_RETURN_IF_ERROR(Evaluate(graph, /*distinguish_subgraph*/ true));
Expand Down Expand Up @@ -170,9 +171,9 @@ Status GraphPartitioner::Partition(const onnxruntime::GraphViewer& graph,
if (codegen::CodeGenSettings::Instance().HasOption(kNupharDumpPartition)) {
std::ostringstream stream;
if (graph.IsSubgraph()) {
stream << "[NUPHAR_DUMP_PARTITION] ## Subgraph ## " << std::endl;
stream << "[NUPHAR_DUMP_PARTITION] ## Subgraph ## Fused graph ID " << fused_count << std::endl;
} else {
stream << "[NUPHAR_DUMP_PARTITION]" << std::endl;
stream << "[NUPHAR_DUMP_PARTITION] ## Fused graph ID " << fused_count << std::endl;
}
stream << "Partition of size " << iter.second.nodes.size() << " [";
for (const auto& node_index : partition->nodes) {
Expand All @@ -186,6 +187,7 @@ Status GraphPartitioner::Partition(const onnxruntime::GraphViewer& graph,
result.emplace_back(
ToCapacity(
graph,
fused_count++,
partition));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class GraphPartitioner : public Partitioner {
: Partitioner(), is_op_type_supported_func_(is_op_type_supported_func) {}

Status Partition(const onnxruntime::GraphViewer& graph,
int& fused_count,
std::vector<std::unique_ptr<ComputeCapability>>& result);

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ Status SubgraphPartitioner::Partition(

if (codegen::CodeGenSettings::Instance().HasOption(kNupharDumpFusedNodes)) {
std::ostringstream stream;
stream << "[NUPHAR_DUMP_FUSED_NODES]" << std::endl;
stream << "[NUPHAR_DUMP_FUSED_NODES] ID " << subgraph.UniqueId() << std::endl;
stream << "NupharSubgraphUnit of size " << results.back().nodes.size() << " [";
for (const auto& n : results.back().nodes) {
stream << "(" << n->Name() << ", " << n->OpType() << ") ";
Expand Down