@@ -136,9 +136,9 @@ def as_numpy(tensor, copy=False):
136136 numpy.ndarray
137137 """
138138 if isinstance (tensor , core .LoDTensorArray ):
139- return [as_numpy (t ) for t in tensor ]
139+ return [as_numpy (t , copy ) for t in tensor ]
140140 if isinstance (tensor , list ):
141- return [as_numpy (t ) for t in tensor ]
141+ return [as_numpy (t , copy ) for t in tensor ]
142142 assert isinstance (tensor , core .LoDTensor )
143143 lod = tensor .lod ()
144144 if len (lod ) > 0 :
@@ -383,6 +383,17 @@ def _to_str(var):
383383 return _to_str (var )
384384
385385
386+ def _is_enable_standalone_executor ():
387+ """
388+ Whether to use experimental executor `StandaloneExecutor`.
389+ """
390+ flag = False
391+ env_val = os .environ .get ('FLAGS_USE_STANDALONE_EXECUTOR' , None )
392+ if env_val in [1 , '1' , True , 'True' , 'true' ]:
393+ flag = True
394+ return flag
395+
396+
def _get_strong_program_cache_key(program, feed, fetch_list):
    """Build a cache key that is unique per Program object identity.

    Prefixing the program's ``id()`` onto the weaker feed/fetch key ensures
    two distinct Program instances never share a cache entry.
    """
    weak_key = _get_program_cache_key(feed, fetch_list)
    return "{}{}".format(id(program), weak_key)
388399
@@ -472,6 +483,121 @@ def handler(self, res_dict):
472483""" )
473484
474485
486+ class _StandaloneExecutor (object ):
487+ def __init__ (self , place , main_program ):
488+ self ._place = core .Place ()
489+ self ._place .set_place (place )
490+ self ._main_program = main_program
491+ self ._new_exe = self ._create_new_executor ()
492+
493+ def run (self , feed , fetch_list , return_numpy = True ):
494+ """
495+ Args:
496+ feed(list|dict): This parameter represents the input Tensors of the model.
497+ If it is single card training, the feed is dict type, and if it is multi-card
498+ training, the parameter feed can be dict or list of Tensors. If the
499+ parameter type is dict, the data in the feed will be split and sent to
500+ multiple devices (CPU/GPU), that is to say, the input data will be evenly
501+ sent to different devices, so you should make sure the number of samples of
502+ the current mini-batch must be greater than the number of places;
503+ if the parameter type is list, those data are copied directly to each device,
504+ so the length of this list should be equal to the number of places.
505+ The default is None.
506+ fetch_list(list): This parameter represents the Tensors that need to be returned
507+ after the model runs. The default is None.
508+ return_numpy(bool): This parameter indicates whether convert the fetched Tensors
509+ (the Tensor specified in the fetch list) to numpy.ndarray. if it is False,
510+ the type of the return value is a list of :code:`LoDTensor`. The default is True.
511+ """
512+ feed = self ._update_feed (feed )
513+ fetch_list = self ._check_fetch (fetch_list )
514+
515+ tensors = self ._new_exe .run (feed , fetch_list )._move_to_list ()
516+ if return_numpy :
517+ return as_numpy (tensors , copy = True )
518+ else :
519+ return tensors
520+
521+ def _create_new_executor (self ):
522+ # NOTE: It's a trick to set empty start_up program.
523+ startup_program = Program ()
524+ outer_scope = global_scope ()
525+ new_exe = core .StandaloneExecutor (self ._place , startup_program .desc ,
526+ self ._main_program .desc , outer_scope )
527+
528+ return new_exe
529+
530+ def _update_feed (self , feed ):
531+ """
532+ Update the feed dict, remove the feed item which is pruned in program.
533+
534+ Notes: This is a very low level API. Users should not use this API
535+ directly.
536+
537+ Args:
538+ feed(list|dict): feed dict or list.
539+
540+ Returns:
541+ feed:(list|dict) updated feed.
542+ """
543+ global_block = self ._main_program .global_block ()
544+ if feed is None :
545+ feed = {}
546+ elif isinstance (feed , dict ):
547+ for feed_name in list (feed .keys ()):
548+ if not global_block .has_var (feed_name ):
549+ feed .pop (feed_name )
550+ warnings .warn (
551+ "The variable %s is not found in program. It is not declared or is pruned."
552+ % feed_name )
553+ else :
554+ raise TypeError ("Only support feed with `dict`, but received {}" .
555+ format (type (feed ).__name__ ))
556+
557+ return feed
558+
559+ def _check_fetch (self , fetch_list ):
560+ if fetch_list is None :
561+ fetch_list = []
562+
563+ res = []
564+ for fetch_var in fetch_list :
565+ if isinstance (fetch_var , Variable ):
566+ fetch_var = fetch_var .name
567+ elif not isinstance (fetch_var , str ):
568+ raise TypeError (
569+ "Required fetch_var shall be str|Variable, but received {}" .
570+ format (type (fetch_var ).__name__ ))
571+
572+ res .append (fetch_var )
573+ return res
574+
575+
576+ class _ExecutorCache (object ):
577+ def __init__ (self , place ):
578+ # {Program : _StandaloneExecutor}
579+ self ._place = place
580+ self ._cached_executors = {}
581+
582+ def run (self , program , feed , fetch_list , return_numpy = True ):
583+ new_exe = self ._get_exe_from_cache (program )
584+ return new_exe .run (feed , fetch_list , return_numpy )
585+
586+ def _get_exe_from_cache (self , program ):
587+ """
588+ Return cached _StandaloneExecutor instance. If not found, create associated
589+ _StandaloneExecutor instance with given program and cache it.
590+ """
591+ assert isinstance (
592+ program , Program ), "Required type(Program), but received {}" .format (
593+ type (program ).__name__ )
594+ if program not in self ._cached_executors :
595+ new_exe = _StandaloneExecutor (self ._place , program )
596+ self ._cached_executors [program ] = new_exe
597+
598+ return self ._cached_executors [program ]
599+
600+
475601class Executor (object ):
476602 """
477603 :api_attr: Static Graph
@@ -568,6 +694,10 @@ def __init__(self, place=None):
568694 self ._auto_checkpoint_name = unique_name .generate (
569695 "__auto_checkpoint_executor__" )
570696
697+ # NOTE: Whether to use experimental executor `StandaloneExecutor`.
698+ self ._enable_interpreter_core = _is_enable_standalone_executor ()
699+ self ._executor_cache = _ExecutorCache (self .place )
700+
571701 def _get_scope_cache (self , program_cache_key ):
572702 return self .scope_caches .get (program_cache_key , None )
573703
@@ -1155,6 +1285,12 @@ def _run_impl(self, program, feed, fetch_list, feed_var_name,
11551285 if scope is None :
11561286 scope = global_scope ()
11571287
1288+ # NOTE: This is an experimental feature. If `export FLAGS_USE_STANDALONE_EXECUTOR=1 `,
1289+ # use StandaloneExecutor to run the program.
1290+ if self ._enable_interpreter_core and not program ._is_start_up_program_ :
1291+ return self ._executor_cache .run (program , feed , fetch_list ,
1292+ return_numpy )
1293+
11581294 # use_prune can be overrided by putting optimize_ops in fetch_list
11591295 _origin_fetch_list = fetch_list
11601296 _origin_program = program
0 commit comments