Commit 6c8a10a

rm detach (#34644)
1 parent 6151ccd commit 6c8a10a

1 file changed: +12 -12 lines changed

python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py

@@ -257,7 +257,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
             for d in tensor_send_prev:
                 paddle.distributed.wait(d, use_calc_stream=True)
                 send_partial(
-                    d.detach(),
+                    d,
                     dst=0,
                     nranks=mp_degree,
                     rank_id=mp_rank,
@@ -266,7 +266,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         else:
             paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
             send_partial(
-                tensor_send_prev.detach(),
+                tensor_send_prev,
                 dst=0,
                 nranks=mp_degree,
                 rank_id=mp_rank,
@@ -277,28 +277,28 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         if isinstance(tensor_recv_prev, tuple):
             for d in tensor_recv_prev:
                 recv_partial(
-                    d.detach(),
+                    d,
                     src=0,
                     nranks=mp_degree,
                     rank_id=mp_rank,
                     group=_hcg.recv_prev_group,
                     use_calc_stream=True)
                 allgather_partial(
-                    d.detach(),
+                    d,
                     nranks=mp_degree,
                     rank_id=mp_rank,
                     group=mp_group,
                     use_calc_stream=True)
         else:
             recv_partial(
-                tensor_recv_prev.detach(),
+                tensor_recv_prev,
                 src=0,
                 nranks=mp_degree,
                 rank_id=mp_rank,
                 group=_hcg.recv_prev_group,
                 use_calc_stream=True)
             allgather_partial(
-                tensor_recv_prev.detach(),
+                tensor_recv_prev,
                 nranks=mp_degree,
                 rank_id=mp_rank,
                 group=mp_group,
@@ -309,7 +309,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
             for d in tensor_send_next:
                 paddle.distributed.wait(d, use_calc_stream=True)
                 send_partial(
-                    d.detach(),
+                    d,
                     dst=1,
                     nranks=mp_degree,
                     rank_id=mp_rank,
@@ -318,7 +318,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         else:
             paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
             send_partial(
-                tensor_send_next.detach(),
+                tensor_send_next,
                 dst=1,
                 nranks=mp_degree,
                 rank_id=mp_rank,
@@ -329,30 +329,30 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         if isinstance(tensor_recv_next, tuple):
             for d in tensor_recv_next:
                 recv_partial(
-                    d.detach(),
+                    d,
                     src=1,
                     nranks=mp_degree,
                     rank_id=mp_rank,
                     group=_hcg.recv_next_group,
                     use_calc_stream=True)
                 allgather_partial(
-                    d.detach(),
+                    d,
                     nranks=mp_degree,
                     rank_id=mp_rank,
                     group=mp_group,
                     use_calc_stream=True)

         else:
             recv_partial(
-                tensor_recv_next.detach(),
+                tensor_recv_next,
                 src=1,
                 nranks=mp_degree,
                 rank_id=mp_rank,
                 group=_hcg.recv_next_group,
                 use_calc_stream=True)

             allgather_partial(
-                tensor_recv_next.detach(),
+                tensor_recv_next,
                 nranks=mp_degree,
                 rank_id=mp_rank,
                 group=mp_group,
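For context, the change is mechanical: every send_partial, recv_partial, and allgather_partial call site in _p2p_helper now receives the tensor object itself instead of tensor.detach(). The commit message gives no rationale; a plausible reading is that the detached view was redundant here, since per the Paddle docs Tensor.detach() returns a new Tensor that shares data with the original while dropping its autograd history, and these helpers only move bytes between pipeline stages. A minimal sketch of that distinction in plain dygraph code (not the pipeline helpers themselves):

    import paddle

    x = paddle.ones([4])
    y = x.detach()  # new Tensor object: shares x's data, carries no grad history

    # Both call shapes transmit the same bytes; passing x directly simply
    # avoids materializing an extra detached view on every p2p call.
    # send_partial(x.detach(), ...)   # old call shape
    # send_partial(x, ...)            # new call shape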
