@@ -4865,9 +4865,9 @@ def _insert_send_recv(cur_id, prev_id):
48654865 var_shape [0 ] = self .micro_batch_size if var_shape [
48664866 0 ] < 0 else var_shape [0 ]
48674867
4868- numel = np .prod (var . shape )
4869- assert numel % self .mp_degree == 0 , \
4870- "The numel={} must be divisible by mp_degree={}" . format ( numel , self .mp_degree )
4868+ numel = np .prod (var_shape )
4869+ use_mp = ( self .mp_degree > 1 ) and (
4870+ numel % self .mp_degree == 0 )
48714871
48724872 if 'subprog' in var .name :
48734873 # For recompute, if the checkpoints var is layer_norm_6.tmp_2
@@ -4906,8 +4906,7 @@ def _insert_send_recv(cur_id, prev_id):
49064906 extra_index_info ['index' ] += 1
49074907 block ._insert_op_without_sync (
49084908 index = index + extra_index_info ['index' ],
4909- type = 'send_v2'
4910- if self .mp_degree == 1 else 'partial_send' ,
4909+ type = 'send_v2' if not use_mp else 'partial_send' ,
49114910 inputs = {'X' : var },
49124911 attrs = {
49134912 self ._op_device_key : prev_dev ,
@@ -4943,8 +4942,7 @@ def _insert_send_recv(cur_id, prev_id):
49434942 extra_index_info ['index' ] += 1
49444943 block ._insert_op_without_sync (
49454944 index = index + extra_index_info ['index' ],
4946- type = 'recv_v2'
4947- if self .mp_degree == 1 else 'partial_recv' ,
4945+ type = 'recv_v2' if not use_mp else 'partial_recv' ,
49484946 outputs = {'Out' : [var ]},
49494947 attrs = {
49504948 'out_shape' : var_shape ,
@@ -4959,7 +4957,7 @@ def _insert_send_recv(cur_id, prev_id):
49594957 'id' : self .mp_rank ,
49604958 })
49614959 extra_index_info ['index' ] += 1
4962- if self . mp_degree > 1 :
4960+ if use_mp :
49634961 block ._insert_op_without_sync (
49644962 index = index + extra_index_info ['index' ],
49654963 type = 'partial_allgather' ,
0 commit comments