
Commit c1f688c

fix bug of split
1 parent 5c8849f commit c1f688c

6 files changed: +161 additions, −203 deletions


paddle/fluid/operators/collective/c_embedding_op.cc

2 additions & 26 deletions

@@ -26,9 +26,6 @@ class CEmbeddingOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids", "CEmbeddingOp");
     OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CEmbeddingOp");

-    // auto start_index = ctx->Attrs().Get<int64_t>("start_index");
-    // auto end_index = ctx->Attrs().Get<int64_t>("end_index");
-
     auto table_dims = ctx->GetInputDim("W");
     auto ids_dims = ctx->GetInputDim("Ids");
     int ids_rank = ids_dims.size();
@@ -42,15 +39,6 @@ class CEmbeddingOp : public framework::OperatorWithKernel {
                           "c_embedding's shape = [%s].",
                           table_dims.size(), table_dims));

-    // PADDLE_ENFORCE_EQ(
-    //     end_index - start_index, table_dims[0],
-    //     platform::errors::InvalidArgument(
-    //         "The value of end_index - start_index should be equal to table's
-    //         length."
-    //         "But received end_index - start_index = %d, "
-    //         "table's length = %d.",
-    //         end_index - start_index, table_dims[0]));
-
     auto output_dims = framework::vectorize(ids_dims);
     output_dims.push_back(table_dims[1]);
     ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
@@ -81,20 +69,9 @@ class CEmbeddingOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Out", "The lookup results, which have the same type as W.");

     AddAttr<int64_t>("start_index",
-                     "(int64, default 0) "
-                     "If the value is 0, it makes no effect to lookup. "
-                     "Otherwise the given value indicates padding the output "
-                     "with zeros whenever lookup encounters it in Ids.")
+                     "(int64, default 0) The starting index, "
+                     "and out-of-bounds Ids will be set to 0.")
         .SetDefault(0);
-
-    // AddAttr<int64_t>("end_index",
-    //                  "(int64, default -1) "
-    //                  "If the value is -1, it makes no effect to lookup. "
-    //                  "Otherwise the given value indicates padding the output
-    //                  "
-    //                  "with zeros whenever lookup encounters it in Ids.")
-    //     .SetDefault(1);
-
     AddComment(R"DOC(
 c_embedding Operator.

@@ -153,7 +130,6 @@ class CEmbeddingOpGradVarTypeInference : public framework::VarTypeInference {
     VLOG(3) << "c_embedding_grad op " << framework::GradVarName("W")
             << " is set to LoDTensor";
     ctx->SetOutputType(out_var_name, framework::proto::VarType::LOD_TENSOR);
-    // }
     ctx->SetOutputDataType(out_var_name, ctx->GetInputDataType("W"));
   }
 };
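
To make the new start_index semantics concrete, here is a minimal NumPy sketch of the behaviour the updated attribute description implies: ids that fall outside this rank's table range produce zero rows. The function name, shapes, and values below are illustrative only; this is not the Paddle kernel.

import numpy as np

def c_embedding_reference(table, ids, start_index=0):
    """Rows for ids outside [start_index, start_index + len(table)) come back as zeros."""
    n, d = table.shape
    out = np.zeros(ids.shape + (d,), dtype=table.dtype)
    local = ids - start_index               # shift global ids into this shard's range
    in_range = (local >= 0) & (local < n)   # out-of-bounds ids keep their zero rows
    out[in_range] = table[local[in_range]]
    return out

# A 4-row shard that owns global ids 8..11; ids 7 and 12 fall outside it.
table = np.arange(8, dtype=np.float32).reshape(4, 2)
ids = np.array([[7, 8], [11, 12]])
print(c_embedding_reference(table, ids, start_index=8))  # rows for 7 and 12 are all zeros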

paddle/fluid/operators/collective/c_embedding_op.cu

0 additions & 4 deletions

@@ -82,8 +82,6 @@ class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
     const auto &dev_ctx =
         context.template device_context<platform::CUDADeviceContext>();
     const int64_t start_idx = context.Attr<int64_t>("start_index");
-    // const int64_t end_idx = context.Attr<int64_t>("end_index");
-
     size_t N = table_t->dims()[0];
     size_t D = table_t->dims()[1];
     size_t K = ids_t->numel();
@@ -118,8 +116,6 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
     const auto &dev_ctx =
         context.template device_context<platform::CUDADeviceContext>();
     const int64_t start_idx = context.Attr<int64_t>("start_index");
-    // const int64_t end_idx = context.Attr<int64_t>("end_index");
-
     auto ids_t = context.Input<LoDTensor>("Ids");
     auto d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out"));
     auto d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));

python/paddle/distributed/collective.py

16 additions & 8 deletions

@@ -775,7 +775,7 @@ def _c_identity(tensor, group=None):
     return out


-def _c_concat(tensor, nranks, group=None):
+def _c_concat(tensor, group=None):
     """
     Return allgather of the tensor, mainly used with model parallel.

@@ -791,10 +791,14 @@ def _c_concat(tensor, nranks, group=None):
         return
     ring_id = 0 if group is None else group.id

+    global_rank = _get_global_env().rank
+    rank = global_rank if group is None else group.get_group_rank(global_rank)
+    nranks = _get_global_env().world_size if group is None else group.nranks
+
     if in_dygraph_mode():
         return core.ops.c_concat(tensor, 'ring_id', ring_id, 'use_calc_stream',
-                                 True, 'nranks', nranks, 'use_model_parallel',
-                                 True)
+                                 True, 'rank', rank, 'nranks', nranks,
+                                 'use_model_parallel', True)

     op_type = 'c_concat'
     helper = LayerHelper(op_type, **locals())
@@ -812,12 +816,13 @@ def _c_concat(tensor, nranks, group=None):
             'ring_id': ring_id,
             'use_calc_stream': True,
             'use_model_parallel': True,
-            'nranks': nranks
+            'nranks': nranks,
+            'rank': rank
         })
     return out


-def _c_split(tensor, rank, nranks, group=None):
+def _c_split(tensor, group=None):
     """
     Split tensor evenly among all members, mainly used with model parallel.

@@ -834,6 +839,10 @@ def _c_split(tensor, rank, nranks, group=None):
         return
     ring_id = 0 if group is None else group.id

+    global_rank = _get_global_env().rank
+    rank = global_rank if group is None else group.get_group_rank(global_rank)
+    nranks = _get_global_env().world_size if group is None else group.nranks
+
     if in_dygraph_mode():
         return core.ops.c_split(tensor, 'use_calc_stream', True, 'ring_id',
                                 ring_id, 'rank', rank, 'nranks', nranks,
@@ -884,11 +893,10 @@ def _mp_allreduce(tensor,


 def _c_embedding(x, weight, start_index=0, name=None):
-
     if in_dygraph_mode():
         return core.ops.c_embedding(weight, x, "start_index", start_index)
     else:
-        helper = LayerHelper('_c_embedding', **locals())
+        helper = LayerHelper('c_embedding', **locals())
         dtype = helper.input_dtype(input_param_name='weight')

         check_variable_and_dtype(x, 'input', ['int32', 'int64'], 'embedding')
@@ -1008,7 +1016,7 @@ def _parallel_linear(x,

     if axis == 0:
         if split_tensor:
-            x = _c_split(x, inner_rank, nranks, group=group)
+            x = _c_split(x, group=group)
         else:
             x = _c_identity(x, group=group)

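
A hedged usage sketch of the changed helpers: after this commit, _c_split and _c_concat take only the tensor and an optional group, deriving rank and nranks from the group (or the global environment) internally. The two-GPU launch command, script name, and tensor values below are assumptions for illustration; these remain private helpers, not public API.

# demo.py -- illustrative only.
# Assumed launch (2 GPUs): python -m paddle.distributed.launch --gpus=0,1 demo.py
import paddle
import paddle.distributed as dist
from paddle.distributed import collective

dist.init_parallel_env()

x = paddle.arange(8, dtype='float32').reshape([2, 4])

# No rank/nranks arguments any more: both are inferred from the (default) group.
local = collective._c_split(x)      # each rank keeps its slice of the last dim
full = collective._c_concat(local)  # allgather restores the original last dim
print(full.shape)                   # [2, 4] on every rank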

python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py

19 additions & 59 deletions

@@ -27,31 +27,6 @@
 # language models using model parallelism[J]. arXiv preprint arXiv:1909.08053, 2019. (https://arxiv.org/abs/1909.08053)


-class _EmbeddingInModelParallel(PyLayer):
-    @staticmethod
-    def forward(ctx, masked_input, weight, input_mask, name):
-        output_parallel = F.embedding(
-            masked_input,
-            weight=weight,
-            padding_idx=None,
-            sparse=False,
-            name=name)
-        # Mask the output embedding.
-        output_parallel[input_mask, :] = 0.0
-
-        ctx.save_for_backward(output_parallel, input_mask)
-
-        return output_parallel
-
-    @staticmethod
-    def backward(ctx, dout):
-        output_parallel, input_mask = ctx.saved_tensor()
-        paddle.autograd.backward(tensors=[output_parallel], grad_tensors=[dout])
-        output_parallel.grad[input_mask, :] = 0
-
-        return None, output_parallel.grad, None
-
-
 class VocabParallelEmbedding(Layer):
     def __init__(self,
                  num_embeddings,
@@ -76,9 +51,6 @@ def __init__(self,
         per_part_size = num_embeddings // self.world_size

         self.vocab_start_index = self.rank * per_part_size
-        self.vocab_end_index = self.vocab_start_index + per_part_size
-        # self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
-
         self._dtype = self._helper.get_default_dtype()
         self._size = [per_part_size, embedding_dim]
         self._weight_attr = weight_attr
@@ -92,32 +64,25 @@ def __init__(self,
             is_bias=False)
         self.weight.is_distributed = True

-    def forward(self, input_):
+    def forward(self, x):
         if self.is_mp:
-            # Build the mask.
-            input_mask = paddle.logical_or((input_ < self.vocab_start_index),
-                                           (input_ >= self.vocab_end_index))
-            # Mask the input.
-            masked_input = input_.clone() - self.vocab_start_index
-            masked_input[input_mask] = 0
+            output_parallel = paddle.distributed.collective._c_embedding(
+                x,
+                self.weight,
+                start_index=self.vocab_start_index,
+                name=self._name)
+            output = paddle.distributed.collective._mp_allreduce(
+                output_parallel,
+                group=self.model_parallel_group,
+                use_calc_stream=True,
+                use_model_parallel=True)
         else:
-            masked_input = input_
-
-        output_parallel = F.embedding(
-            masked_input,
-            weight=self.weight,
-            padding_idx=None,
-            sparse=False,
-            name=self._name)
-        # Mask the output embedding.
-        if self.is_mp:
-            output_parallel[input_mask, :] = 0.0
-
-        output = paddle.distributed.collective._mp_allreduce(
-            output_parallel,
-            group=self.model_parallel_group,
-            use_calc_stream=True,
-            use_model_parallel=True)
+            output = F.embedding(
+                x,
+                weight=self.weight,
+                padding_idx=None,
+                sparse=False,
+                name=self._name)
         return output


@@ -188,9 +153,7 @@ def forward(self, x):

         if self.gather_output and self.is_mp:
             output = paddle.distributed.collective._c_concat(
-                output_parallel,
-                nranks=self.world_size,
-                group=self.model_parallel_group)
+                output_parallel, group=self.model_parallel_group)
         else:
             output = output_parallel
         return output
@@ -258,10 +221,7 @@ def forward(self, x):
         else:
             # split last dim
             input_parallel = paddle.distributed.collective._c_split(
-                x,
-                rank=self.rank,
-                nranks=self.world_size,
-                group=self.model_parallel_group)
+                x, group=self.model_parallel_group)

         output_parallel = F.linear(input_parallel, self.weight, name=self._name)

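
For intuition, here is a small NumPy-only sketch (no Paddle; the helper name and sizes are illustrative) of why the new VocabParallelEmbedding.forward is equivalent to a full embedding lookup: each rank performs a c_embedding-style lookup over its own vocab shard, contributing zero rows for ids it does not own, and summing the per-rank partials plays the role of _mp_allreduce.

import numpy as np

def shard_lookup(shard, ids, start_index):
    """c_embedding-style lookup on one shard: ids outside the shard give zero rows."""
    n, d = shard.shape
    out = np.zeros(ids.shape + (d,), dtype=shard.dtype)
    local = ids - start_index
    ok = (local >= 0) & (local < n)
    out[ok] = shard[local[ok]]
    return out

vocab, dim, world_size = 8, 3, 2
full_table = np.random.default_rng(0).standard_normal((vocab, dim)).astype('float32')
ids = np.array([1, 5, 7])

# Each rank owns a contiguous block of vocab // world_size rows,
# starting at rank * per_part (the layer's vocab_start_index).
per_part = vocab // world_size
partials = [
    shard_lookup(full_table[r * per_part:(r + 1) * per_part], ids, r * per_part)
    for r in range(world_size)
]

# Summing the partials (the job of _mp_allreduce) recovers the full lookup.
assert np.allclose(sum(partials), full_table[ids])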
