@@ -27,6 +27,7 @@ class HybridParallelInferenceHelper(object):
2727 Args:
2828 startup_program (Program): the startup program.
2929 main_program (Program): the main program.
30+ num_mp (int): the model parallel degree. Default ``1``.
3031 num_pp (int): the pipeline parallel degree. Default ``1``.
3132 micro_batch_size (int): the micro batch size. Default ``1``.
3233 beam_size (int): the beam search size. Default ``1``.
@@ -38,6 +39,9 @@ class HybridParallelInferenceHelper(object):
3839 None.
3940
4041 Write Paradigm:
42+
43+ .. code-block:: bash
44+ :name: bash-example1
4145
4246 # while op pattern
4347 with paddle.fluid.device_guard(f'{device}:all'):
@@ -180,6 +184,7 @@ class HybridParallelInferenceHelper(object):
180184 def __init__ (self ,
181185 startup_program ,
182186 main_program ,
187+ num_mp = 1 ,
183188 num_pp = 1 ,
184189 micro_batch_size = 1 ,
185190 beam_size = 1 ,
@@ -234,9 +239,9 @@ def __init__(self,
234239 self .current_endpoint = self .endpoints [self .role_maker ._worker_index ()]
235240 self .rank = self .role_maker ._worker_index ()
236241 self .nranks = self .role_maker ._worker_num ()
237- assert self . nranks % num_pp == 0
242+ assert num_mp * num_pp == self . nranks
238243 self .num_pp = num_pp
239- self .num_mp = self . nranks // self . num_pp
244+ self .num_mp = num_mp
240245
241246 # global ring info
242247 self .global_endpoints = self .endpoints
@@ -398,24 +403,24 @@ def _split_program(self, program, stage, block_idx):
398403
399404 return used_var_names
400405
401- def _find_post_op (self , index , var_name ):
402- """
403- Find the post op that has variable named var_name as input.
404- """
405- # bugfix for uniform hybrid parallelism
406- if '.cast_fp32' in var_name :
407- var_name = var_name .replace ('.cast_fp32' , '' )
408- if '.cast_fp16' in var_name :
409- var_name = var_name .replace ('.cast_fp16' , '' )
410-
411- post_ops = self ._input_var_to_op [var_name ]
412- if post_ops == None : return None
413- result_op = None
414- for post_op , post_idx in reversed (post_ops ):
415- if post_idx > index :
416- result_op = post_op
417- break
418- return result_op
406+ # def _find_post_op(self, index, var_name):
407+ # """
408+ # Find the post op that has variable named var_name as input.
409+ # """
410+ # # bugfix for uniform hybrid parallelism
411+ # if '.cast_fp32' in var_name:
412+ # var_name = var_name.replace('.cast_fp32', '')
413+ # if '.cast_fp16' in var_name:
414+ # var_name = var_name.replace('.cast_fp16', '')
415+
416+ # post_ops = self._input_var_to_op[var_name]
417+ # if post_ops == None: return None
418+ # result_op = None
419+ # for post_op, post_idx in reversed(post_ops):
420+ # if post_idx > index:
421+ # result_op = post_op
422+ # break
423+ # return result_op
419424
420425 def _find_prev_op (self , index , var_name ):
421426 """
@@ -757,7 +762,7 @@ def gen_infer_program(self,
757762 while_block , sync_in_while_lastpp2firstpp_var_names ,
758763 sync_in_while_var_names , self ._stage )
759764
760- # step3: split programs
765+ # step3: split programs
761766 self ._split_program (self ._startup_program , self ._stage , 0 )
762767 self ._split_program (self ._main_program , self ._stage , 0 )
763768
0 commit comments