Skip to content

Commit 4bc0853

Browse files
authored
Integrate StandaloneExecutor in Static.Executor Interface with FLAGS_USE_STANDALONE_EXECUTOR (#35628)
* Integrate StandaloneExecutor in Static.Executor Interface with FLAGS_USE_STANDALONE_EXECUTOR * Enhance unittest and clean code in StandaloneExecutor * polish unittest
1 parent 0b8664e commit 4bc0853

File tree

5 files changed

+283
-41
lines changed

5 files changed

+283
-41
lines changed

paddle/fluid/framework/new_executor/interpretercore_util.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,13 @@ void build_variable_scope(const framework::ProgramDesc& pdesc,
117117
info.var_ref_count_ = 0;
118118
info.vardesc_ = var;
119119
var_scope->vec_meta_info_.push_back(info);
120+
} else {
121+
auto var_id = var_scope->name2id[var->Name()];
122+
if (nullptr == var_scope->vec_meta_info_[var_id].vardesc_) {
123+
VLOG(3) << "update var:" << var->Name() << " desc from nullptr into "
124+
<< var;
125+
var_scope->vec_meta_info_[var_id].vardesc_ = var;
126+
}
120127
}
121128
}
122129
}

paddle/fluid/framework/new_executor/standalone_executor.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,13 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
3535
auto v = outer_scope_->Var(name);
3636
if (global_scope_.name2id.find(name) == global_scope_.name2id.end()) {
3737
global_scope_.name2id[name] = global_scope_.var_list.size();
38-
}
39-
40-
global_scope_.var_list.push_back(v);
38+
global_scope_.var_list.push_back(v);
4139

42-
VariableMetaInfo info;
43-
info.var_ref_count_ = 0;
44-
info.vardesc_ = nullptr;
45-
global_scope_.vec_meta_info_.push_back(info);
40+
VariableMetaInfo info;
41+
info.var_ref_count_ = 0;
42+
info.vardesc_ = nullptr;
43+
global_scope_.vec_meta_info_.push_back(info);
44+
}
4645
}
4746
}
4847

python/paddle/fluid/executor.py

Lines changed: 138 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,9 @@ def as_numpy(tensor, copy=False):
136136
numpy.ndarray
137137
"""
138138
if isinstance(tensor, core.LoDTensorArray):
139-
return [as_numpy(t) for t in tensor]
139+
return [as_numpy(t, copy) for t in tensor]
140140
if isinstance(tensor, list):
141-
return [as_numpy(t) for t in tensor]
141+
return [as_numpy(t, copy) for t in tensor]
142142
assert isinstance(tensor, core.LoDTensor)
143143
lod = tensor.lod()
144144
if len(lod) > 0:
@@ -383,6 +383,17 @@ def _to_str(var):
383383
return _to_str(var)
384384

385385

386+
def _is_enable_standalone_executor():
387+
"""
388+
Whether to use experimental executor `StandaloneExecutor`.
389+
"""
390+
flag = False
391+
env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', None)
392+
if env_val in [1, '1', True, 'True', 'true']:
393+
flag = True
394+
return flag
395+
396+
386397
def _get_strong_program_cache_key(program, feed, fetch_list):
387398
return str(id(program)) + _get_program_cache_key(feed, fetch_list)
388399

@@ -472,6 +483,121 @@ def handler(self, res_dict):
472483
""")
473484

474485

486+
class _StandaloneExecutor(object):
    """
    Thin Python wrapper around ``core.StandaloneExecutor`` that binds one
    main program to one device place and runs it with feed/fetch handling.
    """

    def __init__(self, place, main_program):
        self._place = core.Place()
        self._place.set_place(place)
        self._main_program = main_program
        self._new_exe = self._create_new_executor()

    def run(self, feed, fetch_list, return_numpy=True):
        """
        Run the bound main program once.

        Args:
            feed(dict|None): Input Tensors of the model keyed by variable
                name. Entries whose variable is not present in the program
                (e.g. pruned) are dropped with a warning. None means no feed.
            fetch_list(list|None): Variables (or their names) whose values
                are returned after the program runs.
            return_numpy(bool): If True, convert the fetched Tensors to
                numpy.ndarray; otherwise return a list of ``LoDTensor``.
        """
        valid_feed = self._update_feed(feed)
        fetch_names = self._check_fetch(fetch_list)

        tensors = self._new_exe.run(valid_feed, fetch_names)._move_to_list()
        return as_numpy(tensors, copy=True) if return_numpy else tensors

    def _create_new_executor(self):
        # NOTE: It's a trick to set empty start_up program.
        return core.StandaloneExecutor(self._place,
                                       Program().desc,
                                       self._main_program.desc,
                                       global_scope())

    def _update_feed(self, feed):
        """
        Drop feed entries whose variables were pruned out of the program.

        Notes: This is a very low level API. Users should not use this API
        directly.

        Args:
            feed(dict|None): feed dict keyed by variable name.

        Returns:
            dict: the (possibly pruned) feed dict.
        """
        block = self._main_program.global_block()
        if feed is None:
            return {}
        if not isinstance(feed, dict):
            raise TypeError("Only support feed with `dict`, but received {}".
                            format(type(feed).__name__))

        # Iterate over a snapshot of the keys since we may pop while looping.
        for name in list(feed.keys()):
            if not block.has_var(name):
                feed.pop(name)
                warnings.warn(
                    "The variable %s is not found in program. It is not declared or is pruned."
                    % name)
        return feed

    def _check_fetch(self, fetch_list):
        # Normalize fetch_list to a list of variable-name strings; a
        # Variable is replaced by its name, anything else is rejected.
        names = []
        for item in (fetch_list or []):
            if isinstance(item, Variable):
                item = item.name
            elif not isinstance(item, str):
                raise TypeError(
                    "Required fetch_var shall be str|Variable, but received {}".
                    format(type(item).__name__))
            names.append(item)
        return names
574+
575+
576+
class _ExecutorCache(object):
    """
    Per-place cache mapping a ``Program`` to its ``_StandaloneExecutor``,
    so each program is prepared by the new executor only once.
    """

    def __init__(self, place):
        # {Program : _StandaloneExecutor}
        self._place = place
        self._cached_executors = {}

    def run(self, program, feed, fetch_list, return_numpy=True):
        """Run ``program`` via its cached (or freshly created) executor."""
        new_exe = self._get_exe_from_cache(program)
        return new_exe.run(feed, fetch_list, return_numpy)

    def _get_exe_from_cache(self, program):
        """
        Return cached _StandaloneExecutor instance. If not found, create associated
        _StandaloneExecutor instance with given program and cache it.
        """
        # Raise instead of assert: assert statements are stripped under
        # `python -O`, which would silently skip this type check.
        if not isinstance(program, Program):
            raise TypeError("Required type(Program), but received {}".format(
                type(program).__name__))
        if program not in self._cached_executors:
            self._cached_executors[program] = _StandaloneExecutor(self._place,
                                                                  program)

        return self._cached_executors[program]
599+
600+
475601
class Executor(object):
476602
"""
477603
:api_attr: Static Graph
@@ -568,6 +694,10 @@ def __init__(self, place=None):
568694
self._auto_checkpoint_name = unique_name.generate(
569695
"__auto_checkpoint_executor__")
570696

697+
# NOTE: Whether to use experimental executor `StandaloneExecutor`.
698+
self._enable_interpreter_core = _is_enable_standalone_executor()
699+
self._executor_cache = _ExecutorCache(self.place)
700+
571701
def _get_scope_cache(self, program_cache_key):
572702
return self.scope_caches.get(program_cache_key, None)
573703

@@ -1155,6 +1285,12 @@ def _run_impl(self, program, feed, fetch_list, feed_var_name,
11551285
if scope is None:
11561286
scope = global_scope()
11571287

1288+
# NOTE: This is an experimental feature. If `export FLAGS_USE_STANDALONE_EXECUTOR=1 `,
1289+
# use StandaloneExecutor to run the program.
1290+
if self._enable_interpreter_core and not program._is_start_up_program_:
1291+
return self._executor_cache.run(program, feed, fetch_list,
1292+
return_numpy)
1293+
11581294
# use_prune can be overrided by putting optimize_ops in fetch_list
11591295
_origin_fetch_list = fetch_list
11601296
_origin_program = program

python/paddle/fluid/framework.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4381,6 +4381,8 @@ def __init__(self):
43814381

43824382
# compiled program, i.e. Graph
43834383
self._graph = None
4384+
# to tag whether is startup_program
4385+
self._is_start_up_program_ = False
43844386

43854387
def _find_var_class_kwargs(self, new_desc):
43864388
# NOTE: not all variables support shape/dtype/lod_level methods.
@@ -5994,6 +5996,7 @@ def _copy_to(self, device, blocking):
59945996
# program is a global instance.
59955997
_main_program_ = Program()
59965998
_startup_program_ = Program()
5999+
_startup_program_._is_start_up_program_ = True
59976000

59986001

59996002
def default_startup_program():
@@ -6142,6 +6145,8 @@ def program_guard(main_program, startup_program=None):
61426145
if startup_program is not None:
61436146
check_type(startup_program, 'startup_program', Program,
61446147
'paddle.static.program_guard')
6148+
# Tag the program __is_start_up as True
6149+
startup_program._is_start_up_program_ = True
61456150
startup_program = switch_startup_program(startup_program)
61466151
try:
61476152
yield

0 commit comments

Comments
 (0)