@@ -894,8 +894,25 @@ def _mp_allreduce(tensor,
                 "use_model_parallel", use_model_parallel)
         else:
             raise ValueError("Unknown parameter: {}.".format(op))
-    else:
-        raise NotImplementedError("No support _mp_allreduce in dygraph mode.")
+
+    op_type = 'c_allreduce_sum'
+    helper = LayerHelper(op_type, **locals())
+    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
+
+    check_variable_and_dtype(
+        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
+        op_type)
+
+    helper.append_op(
+        type=op_type,
+        inputs={'X': tensor},
+        outputs={'Out': out},
+        attrs={
+            'ring_id': ring_id,
+            'use_calc_stream': use_calc_stream,
+            'use_model_parallel': use_model_parallel,
+        })
+    return out
 
 
 def _c_lookup_table(table, index, start_index=0, name=None):
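
A minimal sketch of how the new static-graph branch of _mp_allreduce could be exercised. The program setup below is illustrative only; the communication ring that ring_id 0 refers to is assumed to be created by the usual multi-card launch/fleet initialization, which is omitted here.

import paddle
from paddle.distributed import collective

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[4, 8], dtype='float32')
    # Appends a c_allreduce_sum op with use_model_parallel=True to main_prog;
    # actually running it requires a multi-card launch that sets up ring_id 0.
    out = collective._mp_allreduce(x, use_calc_stream=True,
                                   use_model_parallel=True)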
@@ -915,6 +932,19 @@ def _c_lookup_table(table, index, start_index=0, name=None):
     if in_dygraph_mode():
         return core.ops.c_embedding(table, index, "start_index", start_index)
 
+    op_type = 'c_embedding'
+    helper = LayerHelper(op_type, **locals())
+    dtype = helper.input_dtype(input_param_name='table')
+    check_variable_and_dtype(index, 'input', ['int32', 'int64'], op_type)
+    tmp = helper.create_variable_for_type_inference(dtype)
+    helper.append_op(
+        type='c_embedding',
+        inputs={'Ids': index,
+                'W': table},
+        outputs={'Out': tmp},
+        attrs={"start_index": start_index})
+    return tmp
+
 
 class _Linear(layers.Layer):
     """
@@ -1136,47 +1166,34 @@ def _parallel_embedding(x,
         return
     ring_id = 0 if group is None else group.id
 
-    origin_num_embeddings = origin_size[0]
-    embedding = paddle.nn.Embedding(
-        per_part_embeddings,
-        origin_size[1],
-        padding_idx=per_part_embeddings - 1,
-        sparse=False,
-        weight_attr=param_attr,
-        name=name)
-
-    origin_input_shape = x.shape
-    if len(origin_input_shape) == 2:
-        x = paddle.unsqueeze(x, axis=-1)
-    else:
-        assert origin_input_shape[-1] == 1, (
-            "The last dimension size of x must be 1.")
-    x_shard = paddle.shard_index(x, origin_num_embeddings, num_partitions,
-                                 inner_rank, per_part_embeddings - 1)
-    if len(origin_input_shape) == 2:
-        x_shard = paddle.squeeze(x_shard, axis=-1)
-    emb_out = embedding(x_shard)
+    helper = LayerHelper("_parallel_embedding", **locals())
+
+    per_part_size = per_part_embeddings
+    rank = inner_rank
+
+    vocab_start_index = rank * per_part_size
+    dtype = helper.get_default_dtype()
+    size = [per_part_size, origin_size[1]]
+
+    weight = helper.create_parameter(
+        attr=param_attr, shape=size, dtype=dtype, is_bias=False)
+
+    if num_partitions == 1:
+        return paddle.nn.functional.embedding(
+            x, weight=weight, padding_idx=None, sparse=False, name=name)
+
     startup_block = paddle.static.default_startup_program().global_block()
     main_block = paddle.static.default_main_program().global_block()
-    startup_block.vars[embedding.weight.name].is_distributed = True
-    main_block.vars[embedding.weight.name].is_distributed = True
-    out = main_block.create_var(
-        shape=emb_out.shape,
-        dtype=emb_out.dtype,
-        type=emb_out.type,
-        lod_level=emb_out.lod_level,
-        persistable=False,
-        is_data=False,
-        need_check_feed=emb_out.desc.need_check_feed())
-    main_block.append_op(
-        type='c_allreduce_sum',
-        inputs={'X': emb_out},
-        outputs={'Out': out},
-        attrs={
-            'ring_id': ring_id,
-            'use_calc_stream': True,
-            'use_model_parallel': True
-        })
+    startup_block.vars[weight.name].is_distributed = True
+    main_block.vars[weight.name].is_distributed = True
+
+    output_parallel = paddle.distributed.collective._c_lookup_table(
+        weight, x, start_index=vocab_start_index, name=name)
+    out = paddle.distributed.collective._mp_allreduce(
+        output_parallel,
+        group=group,
+        use_calc_stream=True,
+        use_model_parallel=True)
     return out
 
 
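The rewritten _parallel_embedding drops the shard_index/padding-row trick: each rank now creates only its contiguous slice of the embedding table, looks it up with _c_lookup_table (ids outside the local range contribute zero rows, which is what makes the following allreduce-sum reproduce the full lookup), and the partial results are summed with _mp_allreduce. The shard arithmetic, worked through with assumed example sizes:

# Illustrative numbers only: a 1000-row vocabulary split over 4 cards.
vocab_size, num_partitions = 1000, 4
per_part_size = vocab_size // num_partitions      # 250 rows per rank
for rank in range(num_partitions):
    vocab_start_index = rank * per_part_size      # 0, 250, 500, 750
    print("rank", rank, "owns rows", vocab_start_index,
          "to", vocab_start_index + per_part_size - 1)
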
@@ -1288,11 +1305,11 @@ def split(x,
     if operation == "embedding":
         assert axis == 0, ("We only support to split the weight of embedding "
                            "along the first axis now.")
-        per_part_size = (size[0] + num_partitions - 1) // num_partitions
-        last_part_size = size[0] - per_part_size * (num_partitions - 1)
-        if inner_rank == num_partitions - 1: per_part_size = last_part_size
-        per_part_size += 1  # make the last row as the padding index
+        assert size[0] % num_partitions == 0, \
+            "The length of the vocabulary must be divisible by num_partitions " \
+            "but received vocabulary={} num_partitions={}".format(size[0], num_partitions)
 
+        per_part_size = size[0] // num_partitions
         emb_out = _parallel_embedding(
             x,
             per_part_size,
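
This hunk replaces the old padded last shard with a hard divisibility requirement, so callers of the public paddle.distributed.split API must pick a vocabulary size that is an exact multiple of num_partitions. A usage sketch under that constraint; the sizes and launch details are illustrative, and the script is assumed to be started on two cards with paddle.distributed.launch.

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)
ids = paddle.static.data(name='ids', shape=[None, 1], dtype='int64')
# 8000 % 2 == 0, so the new assert in split() is satisfied.
emb_out = paddle.distributed.split(
    ids, size=(8000, 512), operation="embedding", num_partitions=2)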