Skip to content

Commit 73687d8

Browse files
committed
update the api
1 parent 53ff99a commit 73687d8

2 files changed

Lines changed: 17 additions & 7 deletions

File tree

python/paddle/distributed/auto_parallel/api.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
)
4545
from paddle.framework import core
4646

47-
from .placement_type import check_placements_equal, get_shard_spec, to_dim_map
47+
from .placement_type import check_placements_equal, get_shard_spec
4848
from .random import determinate_rng, rng_state
4949

5050
# These are the auto parallel APIs for the unified version of dynamic and static mode.
@@ -190,11 +190,10 @@ def lazy_init_hook(param, origin_hook):
190190
# lazy init hook with randomness controlling
191191
def _init_func(var, block):
192192
# get the unique rng name
193-
dims_mapping = to_dim_map(
194-
param.placements, len(param.shape)
195-
)
196193
rng_name = determinate_rng(
197-
dist.get_rank(), dims_mapping, param.process_mesh
194+
dist.get_rank(),
195+
process_mesh=param.process_mesh,
196+
placements=param.placements,
198197
)
199198
# actually call the init function
200199
with rng_state(rng_name):

python/paddle/distributed/auto_parallel/random.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,16 @@ def parallel_manual_seed(seed, name=""):
7373
_basic_name = name
7474

7575

76-
def determinate_rng(rank, dims_mapping, process_mesh):
76+
def determinate_rng(
77+
rank, dims_mapping=None, process_mesh=None, placements=None
78+
):
79+
assert process_mesh is not None, "Must provide process mesh"
80+
assert (
81+
dims_mapping is not None or placements is not None
82+
), "Must provide one of dims mapping or placements."
83+
assert not (
84+
dims_mapping is not None and placements is not None
85+
), "Cannot provide dims mapping and placements at same time."
7786
# TODO(JZ-LIANG) Support Mesh with any high rank
7887
# use a string to unique integer hashing algorithm for seed computation.
7988
# instead of using offsets to coordinate seed across devices.
@@ -100,7 +109,9 @@ def determinate_rng(rank, dims_mapping, process_mesh):
100109
seed_ += _mesh_offset * (unique_id + 1)
101110

102111
for i in range(len(process_mesh.shape)):
103-
if i not in dims_mapping:
112+
if (dims_mapping is not None and i not in dims_mapping) or (
113+
placements is not None and not placements[i].is_shard()
114+
):
104115
relative_idx = -1
105116
else:
106117
relative_idx = _get_idx_in_axis(

0 commit comments

Comments
 (0)