Skip to content

Commit f91942f

Browse files
committed
Pipeline: add a program (trainer-context) cache to the executor
1 parent 4032c2e commit f91942f

File tree

1 file changed

+139
-0
lines changed

1 file changed

+139
-0
lines changed

python/paddle/fluid/executor.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,6 +1136,10 @@ def _run_impl(self, program, feed, fetch_list, feed_var_name,
11361136
program = program._pipeline_opt["startup_program"]
11371137
else:
11381138
return self.train_from_dataset(program, fetch_list=fetch_list)
1139+
return self._run_pipeline(
1140+
program,
1141+
fetch_list=fetch_list,
1142+
use_program_cache=use_program_cache)
11391143
if isinstance(program, Program) and \
11401144
len(program.global_block().ops) == 0:
11411145
if use_default_main_program:
@@ -1536,6 +1540,141 @@ def _run_from_dataset(self,
15361540

15371541
return None
15381542

1543+
def _prepare_pipeline_ctx(self,
                          program=None,
                          dataset=None,
                          scope=None,
                          thread=0,
                          is_infer=False,
                          debug=False,
                          fetch_list=None,
                          fetch_info=None,
                          print_period=100,
                          fetch_handler=None,
                          use_program_cache=False):
    """Build (and optionally cache) the execution context for pipeline mode.

    Returns a list ``[trainer_desc, dataset, scope, real_fetch_list]`` that
    ``_run_pipeline`` consumes. When ``use_program_cache`` is True, the
    context is looked up / stored under a key derived from ``program`` and
    ``fetch_list`` so repeated runs skip the expensive trainer-desc build.

    NOTE(review): mutates ``program._pipeline_opt["section_program"]`` in
    place and ignores the caller-supplied ``fetch_list`` afterwards — the
    real fetch targets live in ``real_fetch_list``.
    """
    # Pipeline mode requires a program transformed by the pipeline
    # optimizer; a user-supplied dataset is not supported here.
    assert program._pipeline_opt is not None
    assert dataset is None, "dataset should be None for pipeline mode"

    cache_key = _get_strong_program_cache_key(program, None, fetch_list)
    ctx = self._get_ctx_cache(cache_key)
    if use_program_cache and ctx is not None:
        # Fast path: everything below was already done for this program
        # and fetch list on a previous call.
        return ctx

    # Imported lazily — presumably to avoid a circular import between
    # paddle and fluid.executor; TODO confirm.
    import paddle

    # The following fake dataset is created to call
    # the _prepare_trainer api, and it is meaningless.
    def _get_dataset():
        # Collect the program's data (feed) variables so the placeholder
        # dataset declares the same inputs the program expects.
        data_vars = []
        for var in program.global_block().vars.values():
            if var.is_data:
                data_vars.append(var)
        # NPU builds use a different dataset flavor — TODO confirm why
        # InMemoryDataset is required there.
        if core.is_compiled_with_npu():
            dataset = paddle.fluid.DatasetFactory().create_dataset(
                'InMemoryDataset')
        else:
            dataset = paddle.fluid.DatasetFactory().create_dataset(
                'FileInstantDataset')
        # Minimal placeholder configuration: one thread, one batch, a
        # dummy filelist — the data is never actually read.
        dataset.set_batch_size(1)
        dataset.set_thread(1)
        dataset.set_filelist(['None'])
        dataset.set_use_var(data_vars)
        dataset._prepare_to_run()
        return dataset

    dataset = _get_dataset()

    def _get_real_program_fetch_list():
        # The section program is the per-device slice produced by the
        # pipeline optimizer; only fetch vars that exist in it are kept.
        real_program = program._pipeline_opt["section_program"]
        real_fetch_list = []
        for fetch_var in fetch_list:
            if isinstance(fetch_var, Variable):
                fetch_var_name = fetch_var.name
            else:
                fetch_var_name = fetch_var
            if fetch_var_name in real_program.global_block().vars:
                real_fetch_list.append(fetch_var)

        # Append feed/fetch ops so results land in the 'fetch' variable
        # of the scope (read back by _run_pipeline).
        real_program = self._add_feed_fetch_ops(
            program=real_program,
            feed=[],
            fetch_list=real_fetch_list,
            feed_var_name='feed',
            fetch_var_name='fetch')
        main_block = real_program.block(0)
        for op in main_block.ops:
            # set the op_role of fetch op to Optimize to avoid
            # erase the fetched vars by gc for pipeline
            if op.type == 'fetch':
                op._set_attr(
                    'op_role',
                    core.op_proto_and_checker_maker.OpRole.Optimize)
        return real_program, real_fetch_list

    real_program, real_fetch_list = _get_real_program_fetch_list()

    # Replace the section program with the feed/fetch-augmented version;
    # the trainer must not also receive the original fetch_list, so it
    # is cleared before _prepare_trainer.
    program._pipeline_opt["section_program"] = real_program
    fetch_list = None

    scope, trainer = self._prepare_trainer(
        program=program,
        dataset=dataset,
        scope=scope,
        thread=thread,
        debug=debug,
        fetch_list=fetch_list,
        fetch_info=fetch_info,
        print_period=print_period)

    trainer._set_infer(is_infer)
    trainer._gen_trainer_desc()

    # NOTE: only for debug, very slow
    # self._dump_debug_info(program=program, trainer=trainer)

    # in case of calling _set_use_ps_gpu explicitly
    if dataset.use_ps_gpu is False:
        dataset._set_use_ps_gpu(trainer.proto_desc.use_ps_gpu)
    dataset._dynamic_adjust_before_train(trainer.proto_desc.thread_num)

    trainer_desc = trainer._desc()  # slow, cache
    ctx = [trainer_desc, dataset, scope, real_fetch_list]
    if use_program_cache: self._add_ctx_cache(cache_key, ctx)
    return ctx
1644+
1645+
def _run_pipeline(self,
                  program=None,
                  dataset=None,
                  scope=None,
                  thread=0,
                  is_infer=False,
                  debug=False,
                  fetch_list=None,
                  fetch_info=None,
                  print_period=100,
                  fetch_handler=None,
                  use_program_cache=False):
    """Execute a pipeline-parallel program via the C++ executor.

    Delegates all setup to ``_prepare_pipeline_ctx`` (which may serve a
    cached context when ``use_program_cache`` is True), runs the trainer
    once, tears it down, and returns the fetched tensors converted to
    numpy arrays — or None when nothing was fetched.
    """
    ctx = self._prepare_pipeline_ctx(program, dataset, scope, thread,
                                     is_infer, debug, fetch_list,
                                     fetch_info, print_period,
                                     fetch_handler, use_program_cache)
    trainer_desc, dataset, scope, real_fetch_list = ctx

    # One-shot run: initialize the trainer, execute, then release it.
    executor = self._default_executor
    instance = executor.init_for_dataset(program.desc, trainer_desc,
                                         scope, dataset.dataset)
    executor.run_from_dataset(instance)
    executor.release_trainer(instance)

    # Post-run dataset bookkeeping (mirrors _dynamic_adjust_before_train
    # done during context preparation).
    dataset._dynamic_adjust_after_train()
    dataset._finish_to_run()

    if not real_fetch_list:
        return None
    # Results were written into the scope's 'fetch' variable by the
    # fetch ops added in _prepare_pipeline_ctx.
    fetched = scope.find_var('fetch').get_fetch_list()._move_to_list()
    return as_numpy(fetched)
1677+
15391678
def infer_from_dataset(self,
15401679
program=None,
15411680
dataset=None,

0 commit comments

Comments
 (0)