Skip to content

Commit 600fc2f

Browse files
authored
add graph_key to specific graph's varmap (#60567)
* add graph_key to specific graph's varmap * fix inpalce case * fix inpalce case
1 parent 823b94e commit 600fc2f

File tree

6 files changed

+39
-12
lines changed

6 files changed

+39
-12
lines changed

paddle/fluid/framework/io/save_paddle2cinn_varmap.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@ namespace framework {
2121

2222
void save_paddle2cinn_varmap(
2323
std::unordered_map<std::string, std::string> paddle2cinn_var_map,
24+
int64_t graph_compilation_key,
2425
std::string save_path) {
2526
std::stringstream ss;
27+
ss << "graph_compilation_key:" << std::to_string(graph_compilation_key)
28+
<< "\n";
2629
for (const auto& kv : paddle2cinn_var_map) {
2730
ss << kv.first << ":" << kv.second << "\n";
2831
}

paddle/fluid/framework/io/save_paddle2cinn_varmap.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ namespace framework {
1818

1919
void save_paddle2cinn_varmap(
2020
std::unordered_map<std::string, std::string> paddle2cinn_var_map,
21+
int64_t graph_compilation_key,
2122
std::string save_path);
2223

2324
}

paddle/fluid/framework/io/save_runtime_graph.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,21 @@ void save_string(std::string content,
3535
fout.close();
3636
}
3737

38+
void save_graph_compilation_key(int64_t graph_compilation_key,
39+
std::string type,
40+
std::string saved_path) {
41+
VLOG(6) << type << " will be saved to " << saved_path;
42+
MkDirRecursively(DirName(saved_path).c_str());
43+
44+
std::ofstream fout(saved_path);
45+
PADDLE_ENFORCE_EQ(
46+
static_cast<bool>(fout),
47+
true,
48+
phi::errors::Unavailable("Cannot open %s to save ", saved_path));
49+
fout << std::to_string(graph_compilation_key);
50+
fout.close();
51+
}
52+
3853
std::string node_format(const ir::Node& node, int number) {
3954
return "node_" + std::to_string(number) + " : " + "[" + node.Name() + ", " +
4055
(node.IsOp() ? "op" : "var") + "]";
@@ -78,6 +93,7 @@ void save_graph(const ir::Graph& graph,
7893
}
7994

8095
void save_runtime_cinn_graph(const ir::Graph& graph,
96+
int64_t graph_compilation_key,
8197
std::string clusters_ops,
8298
std::string clusters_inputs,
8399
std::string cluster_outputs,
@@ -91,7 +107,9 @@ void save_runtime_cinn_graph(const ir::Graph& graph,
91107
save_string(cluster_intervals,
92108
"cluster_intervals",
93109
saved_path + "/cluster_intervals.txt");
94-
110+
save_graph_compilation_key(graph_compilation_key,
111+
"graph_compilation_key",
112+
saved_path + "/graph_compilation_key.txt");
95113
save_graph(graph, "graph", saved_path + "/subgraph.txt");
96114
}
97115

paddle/fluid/framework/io/save_runtime_graph.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414
namespace paddle {
1515
namespace framework {
1616
void save_runtime_cinn_graph(const ir::Graph& graph,
17+
int64_t graph_compilation_key,
1718
std::string clusters_ops,
1819
std::string clusters_inputs,
1920
std::string cluster_outputs,

paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -753,20 +753,21 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
753753
subgraph->GetOrInit<std::unordered_set<std::string>>(kSkipGcVarNames);
754754
sub_skip_gc_vars = all_skip_gc_vars;
755755
}
756+
auto compilation_key = cinn_compiler->AddGraph(std::move(subgraph));
757+
VLOG(4) << "Compilation Key:\n"
758+
<< cinn_compiler->ReadableKey(compilation_key);
759+
756760
if (FLAGS_save_static_runtime_data) {
757761
paddle::framework::save_runtime_cinn_graph(
758-
*subgraph,
762+
cinn_compiler->FindGraph(compilation_key),
763+
compilation_key,
759764
cluster_debug_info(cluster_set),
760765
cluster_debug_info(cluster_inputs),
761766
cluster_debug_info(cluster_outputs),
762767
cluster_debug_info(cluster_internals),
763768
FLAGS_static_runtime_data_save_path + "/cluster_" +
764769
std::to_string(++i));
765770
}
766-
auto compilation_key = cinn_compiler->AddGraph(std::move(subgraph));
767-
VLOG(4) << "Compilation Key:\n"
768-
<< cinn_compiler->ReadableKey(compilation_key);
769-
770771
// Replace the found cluster to a new cinn op node
771772
ReplaceSubGraphWithCinnOpNode(cluster_set,
772773
cluster_inputs,

paddle/fluid/operators/cinn/cinn_launch_context.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,15 @@ CinnLaunchContext::CinnLaunchContext(const framework::ir::Graph& graph,
7777
[](const auto& name_view) { return std::string(name_view.data()); });
7878
// build name map between the original variables and compiled ones
7979
BuildVarNameMap(compiled_obj.paddle2cinn_varmap, cinn_argument_names_);
80+
if (FLAGS_save_static_runtime_data) {
81+
auto graph_compilation_key =
82+
std::hash<const framework::ir::Graph*>()((&graph));
83+
paddle::framework::save_paddle2cinn_varmap(
84+
paddle2cinn_varmap_,
85+
graph_compilation_key,
86+
FLAGS_static_runtime_data_save_path +
87+
"/paddle2cinn_varmap/paddle2cinn_varmap.txt");
88+
}
8089

8190
const auto& input_var_names =
8291
graph.Get<std::vector<std::string>>(framework::paddle2cinn::kInputVars);
@@ -193,12 +202,6 @@ void CinnLaunchContext::BuildVarNameMap(
193202
"Size of variables is not euqal, paddle[%ld] vs cinn[%ld]",
194203
paddle2cinn_varmap_.size(),
195204
cinn2paddle_varmap_.size()));
196-
if (FLAGS_save_static_runtime_data) {
197-
paddle::framework::save_paddle2cinn_varmap(
198-
paddle2cinn_varmap_,
199-
FLAGS_static_runtime_data_save_path +
200-
"/paddle2cinn_varmap/paddle2cinn_varmap.txt");
201-
}
202205
}
203206

204207
std::unordered_set<std::string> CinnLaunchContext::GetVisibleVarNames() const {

0 commit comments

Comments
 (0)