
Commit 42c1297

[HybridParallel] update collective split to use c_embedding and mp_allreduce (#33411)
1 parent 9cda9ec commit 42c1297


5 files changed: +84, -136 lines

python/paddle/distributed/collective.py

Lines changed: 62 additions & 45 deletions
@@ -894,8 +894,25 @@ def _mp_allreduce(tensor,
                 "use_model_parallel", use_model_parallel)
         else:
             raise ValueError("Unknown parameter: {}.".format(op))
-    else:
-        raise NotImplementedError("No support _mp_allreduce in dygraph mode.")
+
+    op_type = 'c_allreduce_sum'
+    helper = LayerHelper(op_type, **locals())
+    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
+
+    check_variable_and_dtype(
+        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
+        op_type)
+
+    helper.append_op(
+        type=op_type,
+        inputs={'X': tensor},
+        outputs={'Out': out},
+        attrs={
+            'ring_id': ring_id,
+            'use_calc_stream': use_calc_stream,
+            'use_model_parallel': use_model_parallel,
+        })
+    return out
 
 
 def _c_lookup_table(table, index, start_index=0, name=None):
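
With this hunk, calling _mp_allreduce under static graph no longer raises; it appends a c_allreduce_sum op with use_model_parallel=True to the current program. A minimal graph-construction sketch of that path, assuming a Paddle build that contains this commit (it only builds the program in a single process and never launches a distributed job):

    import paddle
    from paddle.distributed import collective

    paddle.enable_static()

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        # Stand-in for a rank-local partial result that needs to be summed.
        partial = paddle.static.data(name='partial', shape=[4, 8], dtype='float32')
        # Appends a c_allreduce_sum op; ring_id defaults to 0 when group is None.
        summed = collective._mp_allreduce(
            partial, use_calc_stream=True, use_model_parallel=True)

    # Expect 'c_allreduce_sum' among the ops of the main program.
    print([op.type for op in main_prog.global_block().ops])
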
@@ -915,6 +932,19 @@ def _c_lookup_table(table, index, start_index=0, name=None):
     if in_dygraph_mode():
         return core.ops.c_embedding(table, index, "start_index", start_index)
 
+    op_type = 'c_embedding'
+    helper = LayerHelper(op_type, **locals())
+    dtype = helper.input_dtype(input_param_name='table')
+    check_variable_and_dtype(index, 'input', ['int32', 'int64'], op_type)
+    tmp = helper.create_variable_for_type_inference(dtype)
+    helper.append_op(
+        type='c_embedding',
+        inputs={'Ids': index,
+                'W': table},
+        outputs={'Out': tmp},
+        attrs={"start_index": start_index})
+    return tmp
+
 
 class _Linear(layers.Layer):
     """
@@ -1136,47 +1166,34 @@ def _parallel_embedding(x,
         return
     ring_id = 0 if group is None else group.id
 
-    origin_num_embeddings = origin_size[0]
-    embedding = paddle.nn.Embedding(
-        per_part_embeddings,
-        origin_size[1],
-        padding_idx=per_part_embeddings - 1,
-        sparse=False,
-        weight_attr=param_attr,
-        name=name)
-
-    origin_input_shape = x.shape
-    if len(origin_input_shape) == 2:
-        x = paddle.unsqueeze(x, axis=-1)
-    else:
-        assert origin_input_shape[-1] == 1, (
-            "The last dimension size of x must be 1.")
-    x_shard = paddle.shard_index(x, origin_num_embeddings, num_partitions,
-                                 inner_rank, per_part_embeddings - 1)
-    if len(origin_input_shape) == 2:
-        x_shard = paddle.squeeze(x_shard, axis=-1)
-    emb_out = embedding(x_shard)
+    helper = LayerHelper("_parallel_embedding", **locals())
+
+    per_part_size = per_part_embeddings
+    rank = inner_rank
+
+    vocab_start_index = rank * per_part_size
+    dtype = helper.get_default_dtype()
+    size = [per_part_size, origin_size[1]]
+
+    weight = helper.create_parameter(
+        attr=param_attr, shape=size, dtype=dtype, is_bias=False)
+
+    if num_partitions == 1:
+        return paddle.nn.functional.embedding(
+            x, weight=weight, padding_idx=None, sparse=False, name=name)
+
     startup_block = paddle.static.default_startup_program().global_block()
     main_block = paddle.static.default_main_program().global_block()
-    startup_block.vars[embedding.weight.name].is_distributed = True
-    main_block.vars[embedding.weight.name].is_distributed = True
-    out = main_block.create_var(
-        shape=emb_out.shape,
-        dtype=emb_out.dtype,
-        type=emb_out.type,
-        lod_level=emb_out.lod_level,
-        persistable=False,
-        is_data=False,
-        need_check_feed=emb_out.desc.need_check_feed())
-    main_block.append_op(
-        type='c_allreduce_sum',
-        inputs={'X': emb_out},
-        outputs={'Out': out},
-        attrs={
-            'ring_id': ring_id,
-            'use_calc_stream': True,
-            'use_model_parallel': True
-        })
+    startup_block.vars[weight.name].is_distributed = True
+    main_block.vars[weight.name].is_distributed = True
+
+    output_parallel = paddle.distributed.collective._c_lookup_table(
+        weight, x, start_index=vocab_start_index, name=name)
+    out = paddle.distributed.collective._mp_allreduce(
+        output_parallel,
+        group=group,
+        use_calc_stream=True,
+        use_model_parallel=True)
     return out
 
 
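The rewritten _parallel_embedding keeps only the local shard of the weight, looks ids up with _c_lookup_table (the c_embedding op) at start_index = rank * per_part_size, and then sums the per-rank partial results with _mp_allreduce. As I read c_embedding, ids outside [start_index, start_index + per_part_size) produce all-zero rows, which is what lets the sum-allreduce reconstruct the full lookup. A rough NumPy model of that combination (a sketch of my reading of the op, not the actual kernel):

    import numpy as np

    def c_embedding_ref(table_shard, ids, start_index):
        # table_shard: the (per_part_size, emb_dim) rows owned by this rank.
        # In-range ids index the shard; out-of-range ids contribute zero rows.
        local_rows, emb_dim = table_shard.shape
        out = np.zeros(ids.shape + (emb_dim,), dtype=table_shard.dtype)
        in_range = (ids >= start_index) & (ids < start_index + local_rows)
        out[in_range] = table_shard[ids[in_range] - start_index]
        return out

    full = np.arange(12 * 4, dtype=np.float32).reshape(12, 4)   # 12-row table
    ids = np.array([[0, 5, 7, 11]])
    # Two ranks, 6 rows each: rank 0 owns rows [0, 6), rank 1 owns rows [6, 12).
    partial0 = c_embedding_ref(full[0:6], ids, start_index=0)
    partial1 = c_embedding_ref(full[6:12], ids, start_index=6)
    # Summing the partial lookups (the role of _mp_allreduce) recovers full[ids].
    assert np.allclose(partial0 + partial1, full[ids])
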
@@ -1288,11 +1305,11 @@ def split(x,
     if operation == "embedding":
         assert axis == 0, ("We only support to split the weight of embedding "
                            "along the first axis now.")
-        per_part_size = (size[0] + num_partitions - 1) // num_partitions
-        last_part_size = size[0] - per_part_size * (num_partitions - 1)
-        if inner_rank == num_partitions - 1: per_part_size = last_part_size
-        per_part_size += 1  # make the last row as the padding index
+        assert size[0] % num_partitions == 0, \
+            "The length of the vocabulary must be divisible by num_partitions " \
+            "but received vocabulary={} num_partitions={}".format(size[0], num_partitions)
 
+        per_part_size = size[0] // num_partitions
         emb_out = _parallel_embedding(
             x,
             per_part_size,
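
The old split path ceil-divided the vocabulary and appended one padding row to every shard; the new path requires an even split and adds no padding row. A quick before/after sketch of the shard-size arithmetic, using the vocabulary sizes the old and new tests use (8 rows before, 12 rows after):

    num_partitions = 2

    # Old scheme: ceil split plus one trailing padding row per shard.
    vocab = 8
    per_part_old = (vocab + num_partitions - 1) // num_partitions   # 4
    per_part_old += 1                                               # 5, last row is padding

    # New scheme: the vocabulary must divide evenly; no padding row is added.
    vocab = 12
    assert vocab % num_partitions == 0
    per_part_new = vocab // num_partitions                          # 6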

python/paddle/fluid/tests/unittests/parallel_embedding_api.py

Lines changed: 9 additions & 5 deletions
@@ -48,23 +48,27 @@ def get_model(self, main_prog, startup_program, rank):
         with fluid.program_guard(main_prog, startup_program):
             fleet.init(is_collective=True)
             np.random.seed(2020)
-            np_array = np.random.rand(10, 8)
+            # (num_embeddings, embedding_dim) = (12, 8)
+            size = (12, 8)
+            np_array = np.random.rand(size[0], size[1])
             paddle.seed(2020)
-            data_in = paddle.randint(0, 8, shape=(10, 4))
+            data_in = paddle.randint(0, size[0], shape=(10, 4))
 
             data = paddle.static.data(
                 name='tindata', shape=[10, 1000], dtype="float32")
+            per_part_size = size[0] // 2
             if rank == 0:
                 param_attr = paddle.fluid.ParamAttr(
                     initializer=paddle.fluid.initializer.NumpyArrayInitializer(
-                        np_array[0:5, :]), )
+                        np_array[0:per_part_size, :]), )
             else:
                 param_attr = paddle.fluid.ParamAttr(
                     initializer=paddle.fluid.initializer.NumpyArrayInitializer(
-                        np_array[5:10, :]), )
+                        np_array[per_part_size:size[0], :]), )
 
             emb_out = paddle.distributed.split(
-                data_in, (8, 8),
+                data_in,
+                size,
                 operation="embedding",
                 num_partitions=2,
                 weight_attr=param_attr)

python/paddle/fluid/tests/unittests/parallel_embedding_api_none_divisible.py

Lines changed: 0 additions & 76 deletions
This file was deleted.

python/paddle/fluid/tests/unittests/test_collective_api_base.py

Lines changed: 1 addition & 2 deletions
@@ -257,11 +257,10 @@ def check_with_place(self,
         elif col_type == "parallel_embedding":
             result_data = tr0_out[0]
             np.random.seed(2020)
-            need_result = np.random.rand(10, 8)
+            need_result = np.random.rand(12, 8)
             for i in range(result_data.shape[0]):
                 for j in range(result_data.shape[1]):
                     data = result_data[i][j]
-                    if data >= 4: data += 1
                     assert np.allclose(
                         tr0_out[1][i][j], need_result[data], atol=1e-08)
         elif col_type == "row_parallel_linear":
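
My reading of the deleted fix-up: under the old padding-row layout the 10-row reference array held 2 x (4 real + 1 padding) rows, so rank 1's real rows started at row 5 and any global id >= 4 had to be shifted by one before indexing need_result. With the even 6 + 6 split of a 12-row table there is no padding row and the ids index need_result directly. A small sketch of the two mappings (illustrative only):

    # Old layout: reference rows = [rank0 real 0-3, rank0 pad, rank1 real 5-8, rank1 pad],
    # hence the old "if data >= 4: data += 1" shift.
    def old_reference_row(global_id):
        return global_id + 1 if global_id >= 4 else global_id

    # New layout: 12 rows split 6 + 6, ids map to reference rows one-to-one.
    def new_reference_row(global_id):
        return global_id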

python/paddle/fluid/tests/unittests/test_collective_split_embedding_none_divisible.py

Lines changed: 12 additions & 8 deletions
@@ -16,20 +16,24 @@
 import unittest
 import numpy as np
 import paddle
-
-from test_collective_api_base import TestDistBase
+from paddle.distributed import fleet
 
 paddle.enable_static()
 
 
-class TestParallelEmbeddingNoneDivisibleAPI(TestDistBase):
-    def _setup_config(self):
-        pass
+class TestCollectiveSplitAssert(unittest.TestCase):
+    def network(self):
+        fleet.init()
+        data = paddle.static.data(
+            name='tindata', shape=[10, 1000], dtype="float32")
+        emb_out = paddle.distributed.split(
+            data, (7, 8), operation="embedding", num_partitions=2)
 
-    def test_parallel_embedding_none_divisible(self):
-        self.check_with_place("parallel_embedding_api_none_divisible.py",
-                              "parallel_embedding", "nccl")
+    def test_assert(self):
+        with self.assertRaises(AssertionError):
+            self.network()
 
 
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
