Skip to content

Commit 63fff5d

Browse files
committed
fix unittest
1 parent 957ab13 commit 63fff5d

File tree

2 files changed

+30
-23
lines changed

2 files changed

+30
-23
lines changed

python/paddle/distributed/fleet/utils/hybrid_parallel_inference.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class HybridParallelInferenceHelper(object):
2727
Args:
2828
startup_program (Program): the startup program.
2929
main_program (Program): the main program.
30+
num_mp (int): degree of model parallelism. Default ``1``.
3031
num_pp (int): degree of pipeline parallelism. Default ``1``.
3132
micro_batch_size (int): size of each micro batch. Default ``1``.
3233
beam_size (int): beam size used in beam search. Default ``1``.
@@ -38,6 +39,9 @@ class HybridParallelInferenceHelper(object):
3839
None.
3940
4041
Write Paradigm:
42+
43+
.. code-block:: bash
44+
:name: bash-example1
4145
4246
# while op pattern
4347
with paddle.fluid.device_guard(f'{device}:all'):
@@ -180,6 +184,7 @@ class HybridParallelInferenceHelper(object):
180184
def __init__(self,
181185
startup_program,
182186
main_program,
187+
num_mp=1,
183188
num_pp=1,
184189
micro_batch_size=1,
185190
beam_size=1,
@@ -234,9 +239,9 @@ def __init__(self,
234239
self.current_endpoint = self.endpoints[self.role_maker._worker_index()]
235240
self.rank = self.role_maker._worker_index()
236241
self.nranks = self.role_maker._worker_num()
237-
assert self.nranks % num_pp == 0
242+
assert num_mp * num_pp == self.nranks
238243
self.num_pp = num_pp
239-
self.num_mp = self.nranks // self.num_pp
244+
self.num_mp = num_mp
240245

241246
# global ring info
242247
self.global_endpoints = self.endpoints
@@ -398,24 +403,24 @@ def _split_program(self, program, stage, block_idx):
398403

399404
return used_var_names
400405

401-
def _find_post_op(self, index, var_name):
402-
"""
403-
Find the post op that has variable named var_name as input.
404-
"""
405-
# bugfix for uniform hybrid parallelism
406-
if '.cast_fp32' in var_name:
407-
var_name = var_name.replace('.cast_fp32', '')
408-
if '.cast_fp16' in var_name:
409-
var_name = var_name.replace('.cast_fp16', '')
410-
411-
post_ops = self._input_var_to_op[var_name]
412-
if post_ops == None: return None
413-
result_op = None
414-
for post_op, post_idx in reversed(post_ops):
415-
if post_idx > index:
416-
result_op = post_op
417-
break
418-
return result_op
406+
# def _find_post_op(self, index, var_name):
407+
# """
408+
# Find the post op that has variable named var_name as input.
409+
# """
410+
# # bugfix for uniform hybrid parallelism
411+
# if '.cast_fp32' in var_name:
412+
# var_name = var_name.replace('.cast_fp32', '')
413+
# if '.cast_fp16' in var_name:
414+
# var_name = var_name.replace('.cast_fp16', '')
415+
416+
# post_ops = self._input_var_to_op[var_name]
417+
# if post_ops == None: return None
418+
# result_op = None
419+
# for post_op, post_idx in reversed(post_ops):
420+
# if post_idx > index:
421+
# result_op = post_op
422+
# break
423+
# return result_op
419424

420425
def _find_prev_op(self, index, var_name):
421426
"""
@@ -757,7 +762,7 @@ def gen_infer_program(self,
757762
while_block, sync_in_while_lastpp2firstpp_var_names,
758763
sync_in_while_var_names, self._stage)
759764

760-
# step3: split programs
765+
# step3: split programs
761766
self._split_program(self._startup_program, self._stage, 0)
762767
self._split_program(self._main_program, self._stage, 0)
763768

python/paddle/fluid/tests/unittests/hybrid_parallel_inference_helper.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,11 @@ def test_hybrid_parallel_inference_helper_mp1pp2(self):
140140
startup_program,
141141
main_program,
142142
micro_batch_size=2,
143+
num_mp=1,
143144
num_pp=2,
144-
init_comm=nranks > 1)
145-
helper.gen_infer_program(['array_write_0.out'], ['cond_int.tmp_0'])
145+
init_comm=nranks > 1, )
146+
helper.gen_infer_program(
147+
['array_write_0.out'], ['cond_int.tmp_0'], debug=True)
146148

147149
exe = paddle.static.Executor(paddle.CUDAPlace(dev_id))
148150
exe.run(startup_program)

0 commit comments

Comments
 (0)