1313# limitations under the License.
1414
1515import paddle
16- import paddle .distributed as dist
1716
1817_groups = None
1918_hcg = None
2019
2120
2221def initialize_p2p_groups (hcg ):
2322 global _groups , _hcg
24- _groups = [dist .new_group (ranks = group ) for group in hcg .get_p2p_groups ()]
23+ _groups = [
24+ paddle .distributed .new_group (ranks = group )
25+ for group in hcg .get_p2p_groups ()
26+ ]
2527 _hcg = hcg
2628
2729
@@ -33,7 +35,7 @@ def send(tensor, dest_stage):
3335 _is_valid_communciate (src_stage , dest_stage )
3436 group = _get_send_recv_group (src_stage , dest_stage )
3537 dst_rank = _hcg .get_rank_from_stage (stage_id = dest_stage )
36- return dist .broadcast (tensor , src_rank , group = group )
38+ return paddle . distributed .broadcast (tensor , src_rank , group = group )
3739
3840
3941def recv (tensor , src_stage ):
@@ -43,7 +45,7 @@ def recv(tensor, src_stage):
4345 _is_valid_communciate (src_stage , dest_stage )
4446 group = _get_send_recv_group (src_stage , dest_stage )
4547 src_rank = _hcg .get_rank_from_stage (stage_id = src_stage )
46- return dist .broadcast (tensor , src_rank , group = group )
48+ return paddle . distributed .broadcast (tensor , src_rank , group = group )
4749
4850
4951def _is_valid_communciate (src_stage , dest_stage ):
0 commit comments