Skip to content

Commit 156ceeb

Browse files
committed
follow some comments
1 parent 4497db0 commit 156ceeb

3 files changed

Lines changed: 14 additions & 5 deletions

File tree

paddle/fluid/framework/ir/pass.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@ class Graph;
2929
#include "paddle/fluid/platform/mkldnn_helper.h"
3030
#endif
3131

32-
DEFINE_bool(apply_pass_to_program, false,
33-
"Whether to apply IR pass to program");
34-
3532
namespace paddle {
3633
namespace framework {
3734
namespace ir {

paddle/fluid/platform/flags.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,3 +620,15 @@ DEFINE_bool(conv2d_disable_cudnn, false, "Disable cudnn in conv2d");
620620
DEFINE_int32(get_host_by_name_time, 120,
621621
"The maximum time for get host by name time");
622622
#endif
623+
624+
/**
625+
* Distributed related FLAG
626+
* Name: FLAGS_apply_pass_to_program
627+
* Since Version: 2.2.0
628+
* Value Range: bool, default=false
629+
* Example: FLAGS_apply_pass_to_program=true, apply IR Pass to When using
630+
* Fleet APIs.
631+
* Note: Apply IR pass to program. Be only useful when using Fleet APIs.
632+
*/
633+
DEFINE_bool(apply_pass_to_program, false,
634+
"Whether to apply IR pass to program when using Fleet APIs");

python/paddle/distributed/fleet/base/fleet_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import paddle
1919
import os
2020
import numpy as np
21-
from paddle.fluid.framework import dygraph_only
21+
from paddle.fluid.framework import dygraph_only, _global_flags
2222
from paddle.fluid import compiler
2323
from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
2424
from .strategy_compiler import StrategyCompiler
@@ -40,7 +40,7 @@
4040

4141
def apply_ir_passes(main_program, startup_program, config):
4242
build_strategy = config._user_defined_strategy.build_strategy._copy()
43-
if not paddle.fluid.core.globals()['FLAGS_apply_pass_to_program']:
43+
if not _global_flags()['FLAGS_apply_pass_to_program']:
4444
return build_strategy
4545

4646
pipeline_opt = getattr(main_program, "_pipeline_opt", {})

0 commit comments

Comments
 (0)