Commit ec6d78c

[HybridParallel]Fix precision problem of model parallel (#32897)
* fix precision of mp
* fix bug of seed
* fix dp
* print group
1 parent 7c0b96e · commit ec6d78c

14 files changed: +151 -58 lines changed
paddle/fluid/framework/distributed_strategy.proto

Lines changed: 1 addition & 0 deletions

@@ -141,6 +141,7 @@ message PipelineConfig {
 
 message TensorParallelConfig {
   optional int32 tensor_parallel_degree = 1 [ default = 1 ];
+  optional int32 tensor_init_seed = 2 [ default = -1 ];
 }
 
 message DistributedStrategy {

python/paddle/distributed/collective.py

Lines changed: 7 additions & 0 deletions

@@ -98,6 +98,13 @@ def get_group_rank(self, rank):
         else:
             return -1
 
+    def __repr__(self):
+        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
+            self.rank, self.nranks, self.id)
+        debug_str += ", ".join(map(str, self.ranks))
+        debug_str += ". "
+        return debug_str
+
 
 _global_env = None
 
python/paddle/distributed/fleet/base/distributed_strategy.py

Lines changed: 4 additions & 1 deletion

@@ -923,6 +923,8 @@ def tensor_parallel_configs(self):
         **Notes**:
             **Detailed arguments for tensor_parallel_configs**
             **tensor_parallel_degree**: degree of tensor parallel
+            **tensor_init_seed**: parameter initialization random seed
+
 
         Examples:
 
@@ -931,7 +933,8 @@ def tensor_parallel_configs(self):
             import paddle.distributed.fleet as fleet
             strategy = fleet.DistributedStrategy()
             strategy.tensor_parallel = True
-            strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4}
+            strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4,
+                                                "tensor_init_seed": 123}
 
         """
         return get_msg_dict(self.strategy.tensor_parallel_configs)
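For completeness, a hedged round-trip check of the new key (values are illustrative; the getter returns the message fields through get_msg_dict as shown above):

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4,
                                        "tensor_init_seed": 123}
    print(strategy.tensor_parallel_configs)
    # expected to contain both keys, e.g. {'tensor_parallel_degree': 4, 'tensor_init_seed': 123}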

python/paddle/distributed/fleet/base/fleet_base.py

Lines changed: 12 additions & 3 deletions

@@ -17,6 +17,7 @@
 import warnings
 import paddle
 import os
+import numpy as np
 from paddle.fluid.framework import dygraph_only
 from paddle.fluid import compiler
 from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
@@ -28,7 +29,7 @@
 from paddle.fluid.dygraph import parallel_helper
 from . import topology as tp
 from .topology import ParallelMode
-from ..meta_parallel import ModelParallel
+from ..meta_parallel import TensorParallel, model_parallel_random_seed
 from ..meta_parallel import PipelineParallel
 from ..meta_optimizers import HybridParallelOptimizer
 from ..meta_optimizers import HybridParallelGradScaler
@@ -279,6 +280,14 @@ def _init_hybrid_parallel_env(self):
 
         self._hcg = tp.HybridCommunicateGroup(self._topology)
 
+        if self.mp_degree > 1:
+            tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
+            tensor_init_seed = tensor_parallel_configs["tensor_init_seed"]
+            if tensor_init_seed == -1:
+                model_parallel_random_seed()
+            else:
+                model_parallel_random_seed(tensor_init_seed)
+
     def get_hybrid_communicate_group(self):
         assert self._hcg is not None
         return self._hcg
@@ -780,8 +789,8 @@ def forward(self, x):
                 last_comm_group_size_MB,
                 find_unused_parameters=self._user_defined_strategy.
                 find_unused_parameters)
-        elif self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL:
-            distributed_model = ModelParallel(
+        elif self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
+            distributed_model = TensorParallel(
                 model, self._hcg, strategy=self._user_defined_strategy)
         elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
             distributed_model = PipelineParallel(
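A minimal sketch of how the new seed plumbing is reached from user code, assuming the usual fleet.init entry point and hybrid_configs keys (the degrees and seed value below are illustrative, not part of this commit):

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
    strategy.tensor_parallel_configs = {"tensor_init_seed": 123}   # -1 (default) => pick a random seed

    # fleet.init() runs _init_hybrid_parallel_env(); since mp_degree > 1 it forwards
    # tensor_init_seed to model_parallel_random_seed(), which seeds the model-parallel
    # RNG tracker with a different local seed on every tensor-parallel rank.
    fleet.init(is_collective=True, strategy=strategy)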

python/paddle/distributed/fleet/base/topology.py

Lines changed: 3 additions & 3 deletions

@@ -28,7 +28,7 @@
 
 class ParallelMode(object):
     DATA_PARALLEL = 0
-    MODEL_PARALLEL = 1
+    TENSOR_PARALLEL = 1
     PIPELINE_PARALLEL = 2
 
 
@@ -155,12 +155,12 @@ def __init__(self, topology):
         _HYBRID_PARALLEL_GROUP = self
 
     def get_parallel_mode(self):
-        # there are three modes : DataParallel / ModelParallel / PipelineParallel
+        # there are three modes : DataParallel / TensorParallel / PipelineParallel
        if self._mp_degree == 1 and self._pp_degree == 1:
            return ParallelMode.DATA_PARALLEL
        elif self._mp_degree > 1 and self._pp_degree == 1:
            # initialize the seed
-           return ParallelMode.MODEL_PARALLEL
+           return ParallelMode.TENSOR_PARALLEL
        elif self._pp_degree > 1:
            return ParallelMode.PIPELINE_PARALLEL
 

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ def __init__(self, scaler, hcg):
         self._scaler = scaler
         self._hcg = hcg
         self._is_mp = (
-            self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL)
+            self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL)
 
     def scale(self, var):
         return self._scaler.scale(var)

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py

Lines changed: 2 additions & 2 deletions

@@ -90,12 +90,12 @@ def __init__(self, optimizer, hcg, strategy):
         self._strategy = strategy
         self._hcg = hcg
         self._is_mp = (
-            self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL)
+            self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL)
         self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)
 
         if isinstance(self._inner_opt._grad_clip,
                       ClipGradByGlobalNorm) and self._is_mp:
-            logger.warning("using ClipGradByGlobalNorm in ModelParallel, the origin " \
+            logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \
                            "optmizer'grad clip will be changed.")
             self._inner_opt._grad_clip = HybridParallelClipGrad(
                 self._inner_opt._grad_clip, hcg)

python/paddle/distributed/fleet/meta_parallel/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@
 from .parallel_layers import RNGStatesTracker  # noqa: F401
 from .parallel_layers import model_parallel_random_seed  # noqa: F401
 from .parallel_layers import get_rng_state_tracker  # noqa: F401
-from .model_parallel import ModelParallel  # noqa: F401
+from .tensor_parallel import TensorParallel  # noqa: F401
 from .pipeline_parallel import PipelineParallel  # noqa: F401
 
 __all__ = []

python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py

Lines changed: 97 additions & 38 deletions

@@ -41,6 +41,7 @@ def __init__(self,
         self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
 
         self.origin_num_embeddings = num_embeddings
+        self.is_mp = (self.world_size > 1)
 
         per_part_size = (
             num_embeddings + self.world_size - 1) // self.world_size
@@ -50,16 +51,36 @@ def __init__(self,
             per_part_size += 1  # make the last row as the padding index
         self.per_part_size = per_part_size
 
-        self.embedding = paddle.nn.Embedding(
-            per_part_size,
-            embedding_dim,
-            padding_idx=per_part_size - 1,
-            sparse=False,
-            weight_attr=weight_attr,
-            name=name)
-        self.embedding.weight.is_distributed = True
+        self._dtype = self._helper.get_default_dtype()
+        self._size = [per_part_size, embedding_dim]
+        self._weight_attr = weight_attr
+        self._name = name
+
+        if self.is_mp:
+            with get_rng_state_tracker().rng_state():
+                self.weight = self.create_parameter(
+                    attr=self._weight_attr,
+                    shape=self._size,
+                    dtype=self._dtype,
+                    is_bias=False)
+            self.weight[per_part_size - 1] = 0.0
+            self.weight.is_distributed = True
+        else:
+            self.weight = self.create_parameter(
+                attr=self._weight_attr,
+                shape=[num_embeddings, embedding_dim],
+                dtype=self._dtype,
+                is_bias=False)
 
     def forward(self, x):
+        if not self.is_mp:
+            return F.embedding(
+                x,
+                weight=self.weight,
+                padding_idx=None,
+                sparse=False,
+                name=self._name)
+
         origin_input_shape = x.shape
         if len(origin_input_shape) == 2:
             x = paddle.unsqueeze(x, axis=-1)
@@ -72,13 +93,18 @@ def forward(self, x):
         if len(origin_input_shape) == 2:
             x_shard = paddle.squeeze(x_shard, axis=-1)
 
-        emb_out = self.embedding(x_shard)
-        if self.world_size > 1:
-            emb_out = paddle.distributed.collective._mp_allreduce(
-                emb_out,
-                group=self.model_parallel_group,
-                use_calc_stream=True,
-                use_model_parallel=True)
+        emb_out = F.embedding(
+            x_shard,
+            weight=self.weight,
+            padding_idx=self.per_part_size - 1,
+            sparse=False,
+            name=self._name)
+
+        emb_out = paddle.distributed.collective._mp_allreduce(
+            emb_out,
+            group=self.model_parallel_group,
+            use_calc_stream=True,
+            use_model_parallel=True)
         return emb_out
 
 
@@ -96,8 +122,9 @@ def __init__(self,
         )
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
         )
+        self._name = name
+        self.is_mp = (self.world_size > 1)
 
-        self.name = name
         self.gather_output = gather_output
         assert out_features % self.world_size == 0, (
             "Number of column of the weight for linear ({}) must be"
@@ -108,29 +135,45 @@ def __init__(self,
         self._weight_attr = weight_attr
         self._dtype = self._helper.get_default_dtype()
 
-        self.weight = self.create_parameter(
-            shape=[in_features, self.output_size_per_partition],
-            attr=self._weight_attr,
-            dtype=self._dtype)
+        if self.is_mp:
+            with get_rng_state_tracker().rng_state():
+                self.weight = self.create_parameter(
+                    shape=[in_features, self.output_size_per_partition],
+                    attr=self._weight_attr,
+                    dtype=self._dtype,
+                    is_bias=False)
+        else:
+            self.weight = self.create_parameter(
+                shape=[in_features, self.output_size_per_partition],
+                attr=self._weight_attr,
+                dtype=self._dtype,
+                is_bias=False)
+
         self.weight.is_distributed = True
 
         if has_bias:
             # initialize bias to zero like Megatron
             self.bias = self.create_parameter(
                 shape=[self.output_size_per_partition],
                 attr=paddle.nn.initializer.Constant(value=0.0),
-                dtype=self._dtype)
+                dtype=self._dtype,
+                is_bias=True)
             self.bias.is_distributed = True
         else:
             self.bias = None
 
     def forward(self, x):
         # use inner api to process identity
-        input_parallel = paddle.distributed.collective._c_identity(
-            x, group=self.model_parallel_group)
+        if self.is_mp:
+            input_parallel = paddle.distributed.collective._c_identity(
+                x, group=self.model_parallel_group)
+        else:
+            input_parallel = x
+
         output_parallel = F.linear(
-            input_parallel, self.weight, self.bias, name=self.name)
-        if self.gather_output:
+            input_parallel, self.weight, self.bias, name=self._name)
+
+        if self.gather_output and self.is_mp:
             output = paddle.distributed.collective._c_concat(
                 output_parallel,
                 nranks=self.world_size,
@@ -155,37 +198,49 @@ def __init__(self,
         self.input_is_parallel = input_is_parallel
         self._weight_attr = weight_attr
         self._dtype = self._helper.get_default_dtype()
-        self.name = name
+        self._name = name
 
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
         )
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
         )
         self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
 
+        self.is_mp = (self.world_size > 1)
         assert in_features % self.world_size == 0, (
             "Number of row of the weight for linear ({}) must be"
             " divisible by model parallel size ({})".format(in_features,
                                                             self.world_size))
 
         self.input_size_per_partition = in_features // self.world_size
 
-        self.weight = self.create_parameter(
-            shape=[self.input_size_per_partition, self.out_features],
-            attr=self._weight_attr,
-            dtype=self._dtype)
+        if self.is_mp:
+            with get_rng_state_tracker().rng_state():
+                self.weight = self.create_parameter(
+                    shape=[self.input_size_per_partition, self.out_features],
+                    attr=self._weight_attr,
+                    dtype=self._dtype,
+                    is_bias=False)
+        else:
+            self.weight = self.create_parameter(
+                shape=[self.input_size_per_partition, self.out_features],
+                attr=self._weight_attr,
+                dtype=self._dtype,
+                is_bias=False)
+
         self.weight.is_distributed = True
 
         if has_bias:
             self.bias = self.create_parameter(
                 shape=[self.out_features],
                 attr=paddle.nn.initializer.Constant(value=0.0),
-                dtype=self._dtype)
+                dtype=self._dtype,
+                is_bias=True)
         else:
             self.bias = None
 
     def forward(self, x):
-        if self.input_is_parallel:
+        if self.input_is_parallel or (not self.is_mp):
             input_parallel = x
         else:
             # split last dim
@@ -195,12 +250,16 @@ def forward(self, x):
                 nranks=self.world_size,
                 group=self.model_parallel_group)
 
-        output_parallel = F.linear(input_parallel, self.weight, name=self.name)
-        output_ = paddle.distributed.collective._mp_allreduce(
-            output_parallel,
-            group=self.model_parallel_group,
-            use_calc_stream=True,
-            use_model_parallel=True)
+        output_parallel = F.linear(input_parallel, self.weight, name=self._name)
+
+        if self.is_mp:
+            output_ = paddle.distributed.collective._mp_allreduce(
+                output_parallel,
+                group=self.model_parallel_group,
+                use_calc_stream=True,
+                use_model_parallel=True)
+        else:
+            output_ = output_parallel
 
         output = output_ + self.bias if self.bias is not None else output_
         return output
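Taken together, the changes above are the heart of the precision fix: partitioned weights are now created with create_parameter inside get_rng_state_tracker().rng_state(), so each tensor-parallel rank draws its shard from its own local seed while replicated parameters keep following the global seed, and the new is_mp flag lets every layer fall back to ordinary single-card behaviour when the model-parallel world size is 1. Below is a standalone, hedged sketch that simulates the two-seed idea sequentially with paddle.seed (real code keeps both RNG states alive at once through RNGStatesTracker; the seed formula is taken from random.py further down, everything else is illustrative):

    import paddle

    def init_shard(global_seed, rank):
        # per-rank local seed, same formula as model_parallel_random_seed(seed) below
        local_seed = global_seed * 1024 + rank * 100
        paddle.seed(local_seed)                 # per-rank state: partitioned weight shards differ
        w_shard = paddle.randn([4, 8])
        paddle.seed(global_seed)                # shared state: replicated parameters agree
        replicated = paddle.randn([8])
        return w_shard, replicated

    w0, r0 = init_shard(2021, rank=0)
    w1, r1 = init_shard(2021, rank=1)
    print(bool((w0 == w1).all()), bool((r0 == r1).all()))   # expected: False True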

python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py

Lines changed: 9 additions & 4 deletions

@@ -14,6 +14,7 @@
 
 import paddle
 import contextlib
+import numpy as np
 
 __all__ = []
 
@@ -65,14 +66,18 @@ def get_rng_state_tracker():
     return RNG_STATE_TRACKER
 
 
-def model_parallel_random_seed(seed=2048):
+def model_parallel_random_seed(seed=None):
     import paddle.distributed.fleet as fleet
     hcg = fleet.get_hybrid_communicate_group()
     rank = hcg.get_model_parallel_rank()
 
-    local_seed = seed + 1024 + rank
-    global_seed = seed
+    if seed:
+        global_seed = seed
+        local_seed = seed * 1024 + rank * 100
+    else:
+        global_seed = np.random.randint(0, 655350)
+        local_seed = np.random.randint(rank * 10000, (rank + 1) * 10000 - 1)
 
     RNG_STATE_TRACKER.reset()
-    paddle.seed(global_seed)
     RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
+    paddle.seed(global_seed)
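The seed derivation above is simple enough to check by hand; a small sketch of the explicit-seed branch (the helper name is ours, not Paddle's):

    def derive_seeds(seed, rank):
        # mirrors the "if seed:" branch of model_parallel_random_seed
        global_seed = seed                      # shared by every tensor-parallel rank
        local_seed = seed * 1024 + rank * 100   # distinct per rank, used for partitioned weights
        return global_seed, local_seed

    print(derive_seeds(123, 0))   # (123, 125952)
    print(derive_seeds(123, 1))   # (123, 126052)

In the default branch (seed left unset), global_seed is drawn randomly and each rank samples its local seed from the disjoint range [rank * 10000, (rank + 1) * 10000 - 1), which likewise keeps the local seeds from colliding across ranks.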
