
Commit 3b6e31b

Add CPU and NPU operator of c_embedding
1 parent b8e4ab4 commit 3b6e31b

10 files changed: +215 additions, -102 deletions

paddle/fluid/operators/collective/c_allreduce_op.h

Lines changed: 2 additions & 3 deletions

@@ -144,8 +144,7 @@ inline bool ContainsNan(const paddle::platform::NPUDeviceContext& dev_ctx,
   try {
     const auto& runner_mean = paddle::operators::NpuOpRunner(
         "ReduceMeanD", {*in}, {mean}, {{"axes", axes}, {"keep_dims", false}});
-    // FIXME(gongwb): not need to open this.
-    // runner_mean.Run(stream);
+    runner_mean.Run(stream);
     TensorToVector(mean, dev_ctx, &vec);
   } catch (...) {
     LOG(WARNING) << "ContainsNan catch exception";
@@ -241,7 +240,7 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
       case framework::proto::VarType::FP32: {
         if (FLAGS_hccl_check_nan) {
           VLOG(3) << "prepare to FoundNanInf";
-          ContainsNan(*dev_ctx, dev_ctx->stream(), in);
+          found_nan = ContainsNan(*dev_ctx, dev_ctx->stream(), in);
           VLOG(3) << "check_numerics:" << found_nan;
         }
         break;
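
Note: the hunk above actually runs the ReduceMeanD kernel and feeds its result into found_nan. The trick behind this check is that NaN propagates through a mean reduction, so only a single scalar has to be copied off the device to test a whole tensor. A minimal numpy sketch of the same idea (illustration only, not the NPU code path):

import numpy as np

def contains_nan(tensor):
    # NaN propagates through arithmetic, so the mean of the whole
    # tensor is NaN iff some element is NaN (barring inf overflow).
    # Reducing first keeps the device-to-host copy to one scalar.
    return bool(np.isnan(tensor.mean()))

x = np.ones((4, 8), dtype=np.float32)
assert not contains_nan(x)
x[2, 3] = np.nan
assert contains_nan(x)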

paddle/fluid/operators/collective/c_embedding_op.cc

Lines changed: 6 additions & 5 deletions

@@ -74,7 +74,7 @@ class CEmbeddingOpMaker : public framework::OpProtoAndCheckerMaker {
              "(Tensor) The input represents embedding tensors, "
              "which is a learnable parameter.");
     AddInput("Ids",
-             "An input with type int32 or int64"
+             "An input with type int32 or int64 in CPU and GPU, int32 in NPU "
              "contains the ids to be looked up in W.");
     AddOutput("Out", "The lookup results, which have the same type as W.");

@@ -126,16 +126,17 @@ class CEmbeddingOpGrad : public framework::OperatorWithKernel {
     // check valid
     PADDLE_ENFORCE_EQ(table_dims.size(), 2,
                       platform::errors::InvalidArgument(
-                          "npu only accept the dims of table_t == 2"));
+                          "Only accept the dims of table_t == 2"));

     const int64_t start_idx = ctx->Attrs().Get<int64_t>("start_index");
     const int64_t height = table_dims[0];
     const int64_t width = table_dims[1];

     PADDLE_ENFORCE_EQ(
-        (height >= 0 && width >= 0 && start_idx >= 0), true,
-        "height:%ld width:%ld start_idx:%ld must not have negtive values",
-        height, width, start_idx);
+        (height > 0 && width > 0 && start_idx >= 0), true,
+        platform::errors::InvalidArgument(
+            "height:%ld width:%ld start_idx:%ld must not have negtive values",
+            height, width, start_idx));
   }

 protected:

paddle/fluid/operators/collective/c_embedding_op.cu

Lines changed: 1 addition & 1 deletion

@@ -107,7 +107,7 @@ class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
           limit);
     } else {
       PADDLE_THROW(platform::errors::Unavailable(
-          "c_embedding ids only support int32 or int64."));
+          "GPU c_embedding ids only support int32 or int64."));
     }
   }
 };

paddle/fluid/operators/collective/c_embedding_op.h

Lines changed: 16 additions & 10 deletions

@@ -38,12 +38,13 @@ void GetIdsEmbedding(const TIds* ids, size_t ids_len, int64_t start_idx,
     int64_t local = id - start_idx;

     if (local >= 0 && local < height) {
-      /*
-      for (int64_t w = 0; w < width; w++) {
-        out[i * width + w] = table[local * width + w];
-      }
-      */
+      // for (int64_t w = 0; w < width; w++) {
+      //   out[i * width + w] = table[local * width + w];
+      // }
+
       memcpy(out + i * width, table + local * width, width * sizeof(TData));
+    } else {
+      memset(out + i * width, 0, width * sizeof(TData));
     }
   }
 }
@@ -74,7 +75,7 @@ class CEmbeddingOpCPUKernel : public framework::OpKernel<T> {
                                   table_data, height, width, output_data);
     } else {
       PADDLE_THROW(platform::errors::Unavailable(
-          "c_embedding ids only support int32 or int64."));
+          "CPU c_embedding ids only support int32 or int64."));
     }
   }
 };
@@ -108,12 +109,17 @@ class CEmbeddingGradOpCPUKernel : public framework::OpKernel<T> {
     T* table_grad_data =
         table_grad_t->mutable_data<T>(table_t->dims(), context.GetPlace());

+    size_t table_t_mem_size =
+        table_t->numel() * framework::SizeOfType(table_grad_t->type());
+    size_t table_grad_t_mem_size =
+        table_grad_t->numel() * framework::SizeOfType(table_grad_t->type());
+
     VLOG(10) << "table_dims:" << table_t->dims()
-             << ", table_t memory_size:" << table_t->memory_size()
-             << ", table_grad_t memory_size:" << table_grad_t->memory_size()
+             << ", table_t memory_size:" << table_t_mem_size
+             << ", table_grad_t memory_size:" << table_grad_t_mem_size
              << ", start_index:" << start_idx;

-    memset(table_grad_data, 0, table_grad_t->memory_size());
+    memset(table_grad_data, 0, table_grad_t_mem_size);
     const T* d_output_data = d_output_t->data<T>();

     const int64_t height = table_t->dims()[0];
@@ -128,7 +134,7 @@ class CEmbeddingGradOpCPUKernel : public framework::OpKernel<T> {
                                     table_grad_data, height, width, d_output_data);
     } else {
       PADDLE_THROW(platform::errors::Unavailable(
-          "c_embedding ids only support int32 or int64."));
+          "CPU c_embedding ids only support int32 or int64."));
     }
   }
 };
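
Note: the forward kernel now zero-fills output rows whose ids fall outside this shard's [start_idx, start_idx + height) range instead of leaving them uninitialized, and the grad kernel sizes its memset from numel() * SizeOfType(...) rather than memory_size(). A rough numpy sketch of the CPU backward logic, for orientation (c_embedding_grad is a hypothetical helper, not the Paddle API):

import numpy as np

def c_embedding_grad(start, height, width, ids, d_out):
    # Zero-init the local shard's gradient (the memset above), then
    # scatter-add each output-gradient row whose global id falls in
    # [start, start + height); out-of-shard ids contribute nothing.
    table_grad = np.zeros((height, width), dtype=d_out.dtype)
    for i, idx in enumerate(ids.flatten()):
        local = idx - start
        if 0 <= local < height:
            table_grad[local] += d_out.reshape(-1, width)[i]
    return table_grad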

paddle/fluid/operators/collective/c_embedding_op_npu.cc

Lines changed: 16 additions & 14 deletions

@@ -113,10 +113,12 @@ void NPUGetIdsEmbedding(const framework::ExecutionContext &context) {
       framework::make_ddim({table_t->dims()[0] + 1, table_t->dims()[1]});
   framework::LoDTensor table_t_pad;

-  size_t mem_size = table_t->memory_size();
+  size_t mem_size = table_t->numel() * framework::SizeOfType(table_t->type());
   size_t line_mem_size =
       table_t->dims()[1] * framework::SizeOfType(table_t->type());
-  PADDLE_ENFORCE_EQ(line_mem_size % 64, 0, "must align by 64");
+  PADDLE_ENFORCE_EQ(line_mem_size % 64, 0,
+                    platform::errors::InvalidArgument(
+                        "NPU only accept the second dim must align by 64"));

   VLOG(10) << "mem_size:" << mem_size << ",line_mem_size:" << line_mem_size
            << ", pad_shape:" << pad_shape << ", table_dims:" << table_t->dims();
@@ -148,11 +150,9 @@ class CEmbeddingNPUKernel : public framework::OpKernel<T> {
     const auto &index_type = ids_t->type();
     if (index_type == framework::proto::VarType::INT32) {
       NPUGetIdsEmbedding<int32_t, T>(context);
-    } else if (index_type == framework::proto::VarType::INT64) {
-      NPUGetIdsEmbedding<int64_t, T>(context);
     } else {
       PADDLE_THROW(platform::errors::Unavailable(
-          "c_embedding ids only support int32 or int64."));
+          "NPU c_embedding ids only support int32."));
     }
   }
 };
@@ -186,9 +186,10 @@ void NPUUpdateEmbedding(const framework::ExecutionContext &context) {
   // set table_t_pad to zero
   uint8_t *pad_data = reinterpret_cast<uint8_t *>(
       table_t_pad.mutable_data<T>(pad_shape, context.GetPlace()));
-  PADDLE_ENFORCE_NPU_SUCCESS(
-      aclrtMemsetAsync(pad_data, table_t_pad.memory_size(), 0,
-                       table_t_pad.memory_size(), stream));
+  size_t table_t_pad_mem_size =
+      table_t_pad.numel() * framework::SizeOfType(table_t_pad.type());
+  PADDLE_ENFORCE_NPU_SUCCESS(aclrtMemsetAsync(pad_data, table_t_pad_mem_size, 0,
+                                              table_t_pad_mem_size, stream));

   // NOTE(zhiqiu): It seems in cann 20.1, the first input and output
   // can be different tensor, but in cann 20.2+, it does inplace operation.
@@ -200,12 +201,15 @@ void NPUUpdateEmbedding(const framework::ExecutionContext &context) {

   // copy table_t_pad to table_t
   T *dst = table_grad_t->mutable_data<T>(table_t->dims(), context.GetPlace());
-  const size_t mem_size = table_grad_t->memory_size();
+  const size_t mem_size =
+      table_grad_t->numel() * framework::SizeOfType(table_grad_t->type());

   // check align
   size_t line_mem_size =
       table_grad_t->dims()[1] * framework::SizeOfType(table_grad_t->type());
-  PADDLE_ENFORCE_EQ(line_mem_size % 64, 0, "must align by 64");
+  PADDLE_ENFORCE_EQ(line_mem_size % 64, 0,
+                    platform::errors::InvalidArgument(
+                        "NPU only accept the second dim must align by 64"));

   PADDLE_ENFORCE_NPU_SUCCESS(aclrtMemcpyAsync(
       dst, mem_size, pad_data, mem_size, ACL_MEMCPY_DEVICE_TO_DEVICE, stream));
@@ -220,11 +224,9 @@ class CEmbeddingGradNPUKernel : public framework::OpKernel<T> {
     const auto &index_type = ids_t->type();
     if (index_type == framework::proto::VarType::INT32) {
       NPUUpdateEmbedding<int32_t, T>(context);
-    } else if (index_type == framework::proto::VarType::INT64) {
-      NPUUpdateEmbedding<int64_t, T>(context);
     } else {
-      PADDLE_THROW(platform::errors::Unavailable(
-          "c_embedding ids only support int32 or int64."));
+      PADDLE_THROW(
+          platform::errors::Unavailable("c_embedding ids only support int32."));
     }
   }
 };
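
Note: both enforce checks in this file guard the same invariant: one embedding row (second dim times element size) must span a multiple of 64 bytes before the NPU pad/copy path is taken. A small illustrative check (check_npu_row_alignment is a hypothetical helper, not part of the patch):

import numpy as np

def check_npu_row_alignment(width, dtype):
    # Mirrors PADDLE_ENFORCE_EQ(line_mem_size % 64, 0, ...): the
    # byte length of one table row must be 64-byte aligned.
    line_mem_size = width * np.dtype(dtype).itemsize
    if line_mem_size % 64 != 0:
        raise ValueError("row is %d bytes, not 64-byte aligned" %
                         line_mem_size)
    return line_mem_size

check_npu_row_alignment(64, "float32")    # 256 bytes -> passes
# check_npu_row_alignment(10, "float32")  # 40 bytes -> would raise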

paddle/fluid/platform/flags.cc

Lines changed: 1 addition & 1 deletion

@@ -93,7 +93,7 @@ DEFINE_string(selected_npus, "",
               "This option is useful when doing multi process training and "
               "each process have only one device (NPU). If you want to use "
               "all visible devices, set this to empty string.");
-DEFINE_bool(hccl_check_nan, true,
+DEFINE_bool(hccl_check_nan, false,
             "Check Nan in tensor before hccl_allreduce_sum otherwise it'll "
             "core when meets Nan value");
 DEFINE_string(
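
Note: with the default flipped to false, the NaN check before hccl_allreduce_sum becomes opt-in. One way to re-enable it, assuming FLAGS_hccl_check_nan is among the gflags Paddle reads from the environment at startup (an assumption, not verified from this diff):

import os

# Assumption: Paddle picks up FLAGS_* overrides from the environment
# during framework initialization, so set it before importing paddle.
os.environ["FLAGS_hccl_check_nan"] = "true"

import paddle  # import after the env override is deliberate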

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 1 addition & 4 deletions

@@ -91,6 +91,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
     LIST(REMOVE_ITEM TEST_OPS test_c_split)
     LIST(REMOVE_ITEM TEST_OPS test_allgather)
     LIST(REMOVE_ITEM TEST_OPS test_c_identity)
+    LIST(REMOVE_ITEM TEST_OPS test_c_embedding_op)
     LIST(REMOVE_ITEM TEST_OPS test_allreduce)
     LIST(REMOVE_ITEM TEST_OPS test_broadcast)
     LIST(REMOVE_ITEM TEST_OPS test_collective_reduce)
@@ -119,10 +120,6 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
     LIST(REMOVE_ITEM TEST_OPS test_disable_signal_handler)
 endif()

-if(((NOT WITH_ROCM) AND (NOT WITH_GPU) AND (NOT WITH_ASCEND_CL)) OR WIN32)
-    LIST(REMOVE_ITEM TEST_OPS test_c_embedding_op)
-endif()
-
 if(WIN32)
     LIST(REMOVE_ITEM TEST_OPS test_multiprocess_reader_exception)
     LIST(REMOVE_ITEM TEST_OPS test_trainer_desc)

python/paddle/fluid/tests/unittests/c_embedding_op_base.py

Lines changed: 132 additions & 0 deletions

@@ -0,0 +1,132 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+from paddle.framework import core
+
+SEED = 2021
+np.random.seed(SEED)
+
+
+def get_c_embedding(start, end, table, ids):
+    index = ids.flatten()
+    input_mask = (index < start) | (index >= end)
+    masked_input = index - start
+    masked_input[input_mask] = 0
+    output = table[masked_input]
+    output[input_mask] = 0.0
+    return output
+
+
+class TestCEmbeddingCPU(OpTest):
+    def setUp(self):
+        self.init_dtype()
+        self.initcase()
+        if core.is_compiled_with_npu():
+            self.__class__.use_npu = True
+        elif core.is_compiled_with_cuda():
+            self.__class__.exist_fp64_check_grad = True
+
+    def initcase(self):
+        self.op_type = "c_embedding"
+        table = np.random.random((17, 64)).astype(self.dtype)
+        ids = np.random.randint(
+            low=0, high=17 * 2, size=(2, 4)).astype(self.ids_dtype)
+        self.start_index = 10
+        self.end_index = self.start_index + 17
+
+        self.inputs = {'W': table, 'Ids': ids}
+        np_out = get_c_embedding(self.start_index, self.end_index, table, ids)
+        self.outputs = {'Out': np_out.reshape((2, 4, 64))}
+        self.attrs = {'start_index': self.start_index}
+        if core.is_compiled_with_npu():
+            self.__class__.use_npu = True
+
+    def test_check_cpu(self):
+        self.check_output_with_place(core.CPUPlace())
+
+    def test_check_cpu_grad(self):
+        self.check_grad_with_place(core.CPUPlace(), ['W'], 'Out')
+
+    def init_dtype(self):
+        self.dtype = "float32"
+        self.ids_dtype = "int64"
+
+
+class TestCEmbeddingOpBase(TestCEmbeddingCPU):
+    def setUp(self):
+        self.init_dtype()
+        self.initcase()
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            self.check_output_with_place(core.CUDAPlace(0))
+        elif core.is_compiled_with_npu():
+            self.check_output_with_place(core.NPUPlace(0))
+
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            self.check_grad_with_place(core.CUDAPlace(0), ['W'], 'Out')
+        elif core.is_compiled_with_npu():
+            self.check_grad_with_place(core.NPUPlace(0), ['W'], 'Out')
+
+    def init_dtype(self):
+        if core.is_compiled_with_cuda():
+            self.dtype = "float64"
+            self.ids_dtype = "int64"
+        elif core.is_compiled_with_npu():
+            self.dtype = "float32"
+            self.ids_dtype = "int32"
+
+
+class TestCEmbeddingOpFP32(TestCEmbeddingOpBase):
+    def setUp(self):
+        self.init_dtype()
+        self.initcase()
+
+    def initcase(self):
+        self.op_type = "c_embedding"
+        table = np.random.random((17, 64)).astype(self.dtype)
+        ids = np.random.randint(
+            low=0, high=17 * 2, size=(2, 4)).astype(self.ids_dtype)
+        self.start_index = 10
+        ids[0][1] = 12
+        ids[0][2] = 12
+        ids[1][2] = 12
+        ids[1][3] = 12
+        self.end_index = self.start_index + 17
+
+        self.inputs = {'W': table, 'Ids': ids}
+        np_out = get_c_embedding(self.start_index, self.end_index, table, ids)
+        self.outputs = {'Out': np_out.reshape((2, 4, 64))}
+        self.attrs = {'start_index': self.start_index}
+
+        if core.is_compiled_with_npu():
+            self.__class__.use_npu = True
+        elif core.is_compiled_with_cuda():
+            self.__class__.exist_fp64_check_grad = True
+
+    def init_dtype(self):
+        self.dtype = "float32"
+        self.ids_dtype = "int32"
+
+
+if __name__ == "__main__":
+    unittest.main()
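
Note: a quick worked check of get_c_embedding's masking, runnable on its own (tiny shapes chosen for illustration; the helper is a self-contained copy of the one in the test file):

import numpy as np

def get_c_embedding(start, end, table, ids):
    index = ids.flatten()
    input_mask = (index < start) | (index >= end)
    masked_input = index - start
    masked_input[input_mask] = 0
    output = table[masked_input]
    output[input_mask] = 0.0
    return output

table = np.arange(17 * 4, dtype="float32").reshape(17, 4)  # local shard
out = get_c_embedding(10, 27, table, np.array([9, 10, 26, 27]))
assert (out[0] == 0).all() and (out[3] == 0).all()  # ids 9, 27 out of range
assert (out[1] == table[0]).all() and (out[2] == table[16]).all()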

python/paddle/fluid/tests/unittests/npu/test_c_embedding_op_npu.py

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from paddle.fluid.tests.unittests.c_embedding_op_base import TestCEmbeddingCPU, TestCEmbeddingOpBase, TestCEmbeddingOpFP32
+
+paddle.enable_static()
+
+TestCEmbeddingCPU()
+
+TestCEmbeddingOpBase()
+
+TestCEmbeddingOpFP32()
+
+if __name__ == "__main__":
+    unittest.main()
