@@ -1136,6 +1136,10 @@ def _run_impl(self, program, feed, fetch_list, feed_var_name,
                 program = program._pipeline_opt["startup_program"]
             else:
                 return self.train_from_dataset(program, fetch_list=fetch_list)
+            return self._run_pipeline(
+                program,
+                fetch_list=fetch_list,
+                use_program_cache=use_program_cache)
         if isinstance(program, Program) and \
                 len(program.global_block().ops) == 0:
             if use_default_main_program:
@@ -1536,6 +1540,141 @@ def _run_from_dataset(self,
 
         return None
 
+    def _prepare_pipeline_ctx(self,
+                              program=None,
+                              dataset=None,
+                              scope=None,
+                              thread=0,
+                              is_infer=False,
+                              debug=False,
+                              fetch_list=None,
+                              fetch_info=None,
+                              print_period=100,
+                              fetch_handler=None,
+                              use_program_cache=False):
+        assert program._pipeline_opt is not None
+        assert dataset is None, "dataset should be None for pipeline mode"
+
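+        # A prepared pipeline context (trainer desc, helper dataset, scope and
+        # the real fetch list) is cached per program/fetch_list key when
+        # use_program_cache is on, so repeated run() calls can skip the
+        # preparation work below.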
+        cache_key = _get_strong_program_cache_key(program, None, fetch_list)
+        ctx = self._get_ctx_cache(cache_key)
+        if use_program_cache and ctx is not None:
+            return ctx
+
+        import paddle
+
+        # The following fake dataset exists only so that _prepare_trainer
+        # can be called; it never supplies any real data.
+        def _get_dataset():
+            data_vars = []
+            for var in program.global_block().vars.values():
+                if var.is_data:
+                    data_vars.append(var)
+            if core.is_compiled_with_npu():
+                dataset = paddle.fluid.DatasetFactory().create_dataset(
+                    'InMemoryDataset')
+            else:
+                dataset = paddle.fluid.DatasetFactory().create_dataset(
+                    'FileInstantDataset')
+            dataset.set_batch_size(1)
+            dataset.set_thread(1)
+            dataset.set_filelist(['None'])
+            dataset.set_use_var(data_vars)
+            dataset._prepare_to_run()
+            return dataset
+
+        dataset = _get_dataset()
+
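+        # The pipeline optimizer stores the section program to execute in
+        # _pipeline_opt["section_program"]. Only fetch targets that actually
+        # exist in that program can be fetched, and feed/fetch ops are added
+        # to it here because the regular run() feed/fetch path is bypassed.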
+        def _get_real_program_fetch_list():
+            real_program = program._pipeline_opt["section_program"]
+            real_fetch_list = []
+            for fetch_var in fetch_list:
+                if isinstance(fetch_var, Variable):
+                    fetch_var_name = fetch_var.name
+                else:
+                    fetch_var_name = fetch_var
+                if fetch_var_name in real_program.global_block().vars:
+                    real_fetch_list.append(fetch_var)
+
+            real_program = self._add_feed_fetch_ops(
+                program=real_program,
+                feed=[],
+                fetch_list=real_fetch_list,
+                feed_var_name='feed',
+                fetch_var_name='fetch')
+            main_block = real_program.block(0)
+            for op in main_block.ops:
+                # set the op_role of the fetch ops to Optimize so that the
+                # fetched vars are not erased by gc in pipeline mode
+                if op.type == 'fetch':
+                    op._set_attr(
+                        'op_role',
+                        core.op_proto_and_checker_maker.OpRole.Optimize)
+            return real_program, real_fetch_list
+
+        real_program, real_fetch_list = _get_real_program_fetch_list()
+
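+        # From here on, the instrumented section program (with feed/fetch ops)
+        # replaces the original one, and fetch_list is cleared: results are
+        # read back from the 'fetch' variable in the scope after the run
+        # instead of going through the regular fetch mechanism.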
+        program._pipeline_opt["section_program"] = real_program
+        fetch_list = None
+
+        scope, trainer = self._prepare_trainer(
+            program=program,
+            dataset=dataset,
+            scope=scope,
+            thread=thread,
+            debug=debug,
+            fetch_list=fetch_list,
+            fetch_info=fetch_info,
+            print_period=print_period)
+
+        trainer._set_infer(is_infer)
+        trainer._gen_trainer_desc()
+
+        # NOTE: only for debug, very slow
+        # self._dump_debug_info(program=program, trainer=trainer)
+
+        # do not overwrite use_ps_gpu if _set_use_ps_gpu was called explicitly
+        if dataset.use_ps_gpu is False:
+            dataset._set_use_ps_gpu(trainer.proto_desc.use_ps_gpu)
+        dataset._dynamic_adjust_before_train(trainer.proto_desc.thread_num)
+
+        trainer_desc = trainer._desc()  # building the desc is slow, so cache it
+        ctx = [trainer_desc, dataset, scope, real_fetch_list]
+        if use_program_cache:
+            self._add_ctx_cache(cache_key, ctx)
+        return ctx
+
+    def _run_pipeline(self,
+                      program=None,
+                      dataset=None,
+                      scope=None,
+                      thread=0,
+                      is_infer=False,
+                      debug=False,
+                      fetch_list=None,
+                      fetch_info=None,
+                      print_period=100,
+                      fetch_handler=None,
+                      use_program_cache=False):
+        trainer_desc, dataset, scope, real_fetch_list = \
+            self._prepare_pipeline_ctx(program, dataset, scope, thread,
+                                       is_infer, debug, fetch_list, fetch_info,
+                                       print_period, fetch_handler,
+                                       use_program_cache)
+
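+        # init_for_dataset creates a trainer instance in the core executor
+        # from the cached trainer desc; the section program then runs through
+        # the dataset trainer path instead of the normal Executor.run() path.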
+        trainer_instance = self._default_executor.init_for_dataset(
+            program.desc, trainer_desc, scope, dataset.dataset)
+
+        self._default_executor.run_from_dataset(trainer_instance)
+        self._default_executor.release_trainer(trainer_instance)
+
+        dataset._dynamic_adjust_after_train()
+        dataset._finish_to_run()
+        if real_fetch_list:
+            arr = scope.find_var('fetch').get_fetch_list()
+            tensors = arr._move_to_list()
+            return as_numpy(tensors)
+
+        return None
+
     def infer_from_dataset(self,
                            program=None,
                            dataset=None,