@@ -257,7 +257,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
257257 for d in tensor_send_prev :
258258 paddle .distributed .wait (d , use_calc_stream = True )
259259 send_partial (
260- d . detach () ,
260+ d ,
261261 dst = 0 ,
262262 nranks = mp_degree ,
263263 rank_id = mp_rank ,
@@ -266,7 +266,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
266266 else :
267267 paddle .distributed .wait (tensor_send_prev , use_calc_stream = True )
268268 send_partial (
269- tensor_send_prev . detach () ,
269+ tensor_send_prev ,
270270 dst = 0 ,
271271 nranks = mp_degree ,
272272 rank_id = mp_rank ,
@@ -277,28 +277,28 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
277277 if isinstance (tensor_recv_prev , tuple ):
278278 for d in tensor_recv_prev :
279279 recv_partial (
280- d . detach () ,
280+ d ,
281281 src = 0 ,
282282 nranks = mp_degree ,
283283 rank_id = mp_rank ,
284284 group = _hcg .recv_prev_group ,
285285 use_calc_stream = True )
286286 allgather_partial (
287- d . detach () ,
287+ d ,
288288 nranks = mp_degree ,
289289 rank_id = mp_rank ,
290290 group = mp_group ,
291291 use_calc_stream = True )
292292 else :
293293 recv_partial (
294- tensor_recv_prev . detach () ,
294+ tensor_recv_prev ,
295295 src = 0 ,
296296 nranks = mp_degree ,
297297 rank_id = mp_rank ,
298298 group = _hcg .recv_prev_group ,
299299 use_calc_stream = True )
300300 allgather_partial (
301- tensor_recv_prev . detach () ,
301+ tensor_recv_prev ,
302302 nranks = mp_degree ,
303303 rank_id = mp_rank ,
304304 group = mp_group ,
@@ -309,7 +309,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
309309 for d in tensor_send_next :
310310 paddle .distributed .wait (d , use_calc_stream = True )
311311 send_partial (
312- d . detach () ,
312+ d ,
313313 dst = 1 ,
314314 nranks = mp_degree ,
315315 rank_id = mp_rank ,
@@ -318,7 +318,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
318318 else :
319319 paddle .distributed .wait (tensor_send_next , use_calc_stream = True )
320320 send_partial (
321- tensor_send_next . detach () ,
321+ tensor_send_next ,
322322 dst = 1 ,
323323 nranks = mp_degree ,
324324 rank_id = mp_rank ,
@@ -329,30 +329,30 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
329329 if isinstance (tensor_recv_next , tuple ):
330330 for d in tensor_recv_next :
331331 recv_partial (
332- d . detach () ,
332+ d ,
333333 src = 1 ,
334334 nranks = mp_degree ,
335335 rank_id = mp_rank ,
336336 group = _hcg .recv_next_group ,
337337 use_calc_stream = True )
338338 allgather_partial (
339- d . detach () ,
339+ d ,
340340 nranks = mp_degree ,
341341 rank_id = mp_rank ,
342342 group = mp_group ,
343343 use_calc_stream = True )
344344
345345 else :
346346 recv_partial (
347- tensor_recv_next . detach () ,
347+ tensor_recv_next ,
348348 src = 1 ,
349349 nranks = mp_degree ,
350350 rank_id = mp_rank ,
351351 group = _hcg .recv_next_group ,
352352 use_calc_stream = True )
353353
354354 allgather_partial (
355- tensor_recv_next . detach () ,
355+ tensor_recv_next ,
356356 nranks = mp_degree ,
357357 rank_id = mp_rank ,
358358 group = mp_group ,
0 commit comments