Skip to content

Commit c112357

Browse files
committed
solve conflict
2 parents efbd3a4 + c294cca commit c112357

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+1722
-228
lines changed
Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#pragma once
15+
#include <sys/time.h>
16+
17+
#include <iostream>
18+
#include <ostream>
19+
#include <string>
20+
#include <thread> // NOLINT
21+
#include <vector>
22+
23+
#include <ThreadPool.h>
24+
#include "boost/lexical_cast.hpp"
25+
#include "glog/logging.h"
26+
#include "paddle/fluid/distributed/common/utils.h"
27+
#include "paddle/fluid/framework/blocking_queue.h"
28+
#include "paddle/fluid/framework/dim.h"
29+
#include "paddle/fluid/framework/framework.pb.h"
30+
#include "paddle/fluid/framework/tensor.h"
31+
#include "paddle/fluid/framework/tensor_util.h"
32+
#include "paddle/fluid/string/split.h"
33+
34+
// 256 * 1024 * 1024 (i.e. 256 Mi). Not referenced by the code visible in this
// header. NOTE(review): presumably a buffer/flush granularity in bytes —
// confirm against the users of this constant.
constexpr int FG = 256 * 1024 * 1024;
35+
// Not referenced by the code visible in this header. NOTE(review): looks like
// a queue capacity — confirm.
constexpr int Q_SIZE = 10000;
36+
// Number of buckets each input shard is split into: the serialization loops
// below iterate `b` over [0, BUCKET) and pass BUCKET to bucket().
constexpr int BUCKET = 10;
37+
// Not referenced by the code visible in this header. NOTE(review): presumably
// an end-of-stream marker string — confirm.
constexpr char XEOF[] = "EOF";
38+
39+
using boost::lexical_cast;
40+
41+
// Returns the current time reported by gettimeofday(), expressed in
// microseconds: the seconds field scaled by 1e6 plus the microseconds field.
inline double GetCurrentUS() {
  struct timeval now;
  gettimeofday(&now, nullptr);
  const double seconds_in_us = 1e+6 * now.tv_sec;
  return seconds_in_us + now.tv_usec;
}
46+
47+
namespace paddle {
48+
namespace distributed {
49+
50+
class ShardingMerge {
51+
public:
52+
ShardingMerge() {}
53+
~ShardingMerge() {}
54+
55+
void Merge(const std::vector<std::string> &inputs,
56+
const std::vector<int64_t> &feasigns, const std::string &output,
57+
const int embedding_dim) {
58+
pool_.reset(new ::ThreadPool(inputs.size()));
59+
60+
std::vector<std::future<int>> tasks(inputs.size());
61+
std::vector<std::vector<int64_t>> rows;
62+
rows.resize(inputs.size());
63+
64+
auto begin = GetCurrentUS();
65+
for (int x = 0; x < inputs.size(); ++x) {
66+
tasks[x] = pool_->enqueue([this, x, &rows, &inputs, &feasigns]() -> int {
67+
DeserializeRowsFromFile(inputs[x], feasigns[x], &rows[x]);
68+
return 0;
69+
});
70+
}
71+
72+
for (size_t x = 0; x < tasks.size(); ++x) {
73+
tasks[x].wait();
74+
}
75+
76+
int64_t total_rows = 0;
77+
for (auto x = 0; x < rows.size(); x++) {
78+
total_rows += rows[x].size();
79+
}
80+
81+
auto end = GetCurrentUS();
82+
83+
VLOG(0) << "got " << total_rows
84+
<< " feasigin ids from sparse embedding using " << end - begin;
85+
86+
std::vector<int64_t> total_dims = {total_rows,
87+
static_cast<int64_t>(embedding_dim)};
88+
89+
std::vector<std::vector<int>> batch_buckets;
90+
batch_buckets.resize(inputs.size());
91+
92+
for (int x = 0; x < rows.size(); ++x) {
93+
batch_buckets[x] = bucket(rows[x].size(), BUCKET);
94+
}
95+
96+
std::ofstream out(output, std::ios::binary);
97+
98+
begin = GetCurrentUS();
99+
SerializeRowsToStream(out, rows, batch_buckets, total_rows);
100+
end = GetCurrentUS();
101+
VLOG(0) << "write rows to oostrream using " << end - begin;
102+
103+
begin = GetCurrentUS();
104+
SerializePreTensorToStream(out, total_dims);
105+
end = GetCurrentUS();
106+
VLOG(0) << "write pretensor to oostrream using " << end - begin;
107+
108+
begin = GetCurrentUS();
109+
SerializeValueToStream(out, inputs, batch_buckets, embedding_dim);
110+
end = GetCurrentUS();
111+
VLOG(0) << "write values to oostrream using " << end - begin;
112+
}
113+
114+
private:
115+
void SerializeRowsToStream(std::ostream &os,
116+
const std::vector<std::vector<int64_t>> &rows,
117+
const std::vector<std::vector<int>> &batch_buckets,
118+
int64_t total_rows) {
119+
{ // the 1st field, uint32_t version
120+
constexpr uint32_t version = 0;
121+
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
122+
}
123+
124+
{
125+
// the 2st field, rows information
126+
os.write(reinterpret_cast<const char *>(&total_rows), sizeof(total_rows));
127+
128+
for (int b = 0; b < BUCKET; ++b) {
129+
for (int x = 0; x < batch_buckets.size(); ++x) {
130+
auto begin = batch_buckets[x][b];
131+
auto end = batch_buckets[x][b + 1];
132+
133+
if (end - begin == 0) continue;
134+
135+
os.write(reinterpret_cast<const char *>(rows[x].data() + begin),
136+
sizeof(int64_t) * (end - begin));
137+
}
138+
}
139+
140+
// the 3st field, the height of SelectedRows
141+
int64_t height = total_rows;
142+
os.write(reinterpret_cast<const char *>(&height), sizeof(height));
143+
}
144+
}
145+
146+
void SerializePreTensorToStream(std::ostream &os,
147+
const std::vector<int64_t> &dims) {
148+
{ // the 1st field, uint32_t version
149+
constexpr uint32_t version = 0;
150+
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
151+
}
152+
{ // the 2nd field, tensor description
153+
// int32_t size
154+
framework::proto::VarType::TensorDesc desc;
155+
desc.set_data_type(framework::proto::VarType::FP32);
156+
auto *pb_dims = desc.mutable_dims();
157+
pb_dims->Resize(static_cast<int>(dims.size()), 0);
158+
std::copy(dims.begin(), dims.end(), pb_dims->begin());
159+
int32_t size = desc.ByteSize();
160+
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
161+
auto out = desc.SerializeAsString();
162+
os.write(out.data(), size);
163+
}
164+
}
165+
166+
void SerializeValueToVec(std::ifstream &in, const int batch,
167+
const int embedding_dim, std::vector<float> *out) {
168+
auto queue =
169+
std::make_shared<framework::BlockingQueue<std::vector<std::string>>>();
170+
171+
auto read = [batch, &in, &queue]() {
172+
std::string line;
173+
std::vector<std::string> columns;
174+
std::vector<std::string> values_str;
175+
176+
int count = 0;
177+
178+
while (std::getline(in, line)) {
179+
++count;
180+
columns = string::Split(line, '\t');
181+
182+
if (columns.size() != 5) {
183+
VLOG(0) << "unexpected line: " << line << ", skip it";
184+
continue;
185+
}
186+
187+
values_str = string::Split(columns[4], ',');
188+
queue->Push(values_str);
189+
190+
if (count >= batch) {
191+
break;
192+
}
193+
}
194+
queue->Push({});
195+
};
196+
197+
auto write = [embedding_dim, &out, &queue]() {
198+
std::vector<std::string> values_str;
199+
std::string line;
200+
201+
while (true) {
202+
queue->Pop(&values_str);
203+
204+
if (values_str.size() == 0) {
205+
break;
206+
}
207+
208+
for (int x = 0; x < embedding_dim; ++x) {
209+
float v = 0.0;
210+
try {
211+
v = lexical_cast<float>(values_str[x]);
212+
} catch (boost::bad_lexical_cast &e) {
213+
VLOG(0) << " get unexpected line: " << line;
214+
}
215+
out->push_back(v);
216+
}
217+
}
218+
};
219+
220+
std::thread p_read(read);
221+
std::thread p_write(write);
222+
p_read.join();
223+
p_write.join();
224+
}
225+
226+
void SerializeVecToStream(std::ostream &out,
227+
const std::vector<float> &value) {
228+
out.write(reinterpret_cast<const char *>(value.data()),
229+
static_cast<std::streamsize>(sizeof(float) * value.size()));
230+
}
231+
232+
void SerializeValueToStream(
233+
std::ostream &out, const std::vector<std::string> &ins,
234+
const std::vector<std::vector<int>> &batch_buckets,
235+
const int embedding_dim) {
236+
std::vector<std::shared_ptr<std::ifstream>> in_streams;
237+
238+
for (int x = 0; x < ins.size(); ++x) {
239+
in_streams.emplace_back(std::make_shared<std::ifstream>(ins[x]));
240+
}
241+
242+
std::vector<std::future<int>> tasks(ins.size());
243+
244+
for (int b = 0; b < BUCKET; ++b) {
245+
std::vector<std::vector<float>> values;
246+
values.resize(tasks.size());
247+
248+
auto begin = GetCurrentUS();
249+
250+
for (int x = 0; x < tasks.size(); ++x) {
251+
auto batch = batch_buckets[x][b + 1] - batch_buckets[x][b];
252+
values[x].clear();
253+
values[x].reserve(batch * embedding_dim);
254+
}
255+
256+
for (int x = 0; x < tasks.size(); ++x) {
257+
tasks[x] =
258+
pool_->enqueue([this, b, x, &out, &in_streams, &batch_buckets,
259+
&values, embedding_dim]() -> int {
260+
auto batch = batch_buckets[x][b + 1] - batch_buckets[x][b];
261+
if (batch == 0) return 0;
262+
SerializeValueToVec(*(in_streams[x].get()), batch, embedding_dim,
263+
&values[x]);
264+
return 0;
265+
});
266+
}
267+
268+
for (size_t x = 0; x < tasks.size(); ++x) {
269+
tasks[x].wait();
270+
}
271+
272+
auto end = GetCurrentUS();
273+
274+
auto begin1 = GetCurrentUS();
275+
for (size_t x = 0; x < tasks.size(); ++x) {
276+
SerializeVecToStream(out, values[x]);
277+
}
278+
auto end1 = GetCurrentUS();
279+
280+
VLOG(0) << "serialize buckets " << b << " read using " << end - begin
281+
<< ", to oostream using " << end1 - begin1;
282+
}
283+
}
284+
285+
void DeserializeRowsFromFile(const std::string &input_file,
286+
const int64_t feasigns,
287+
std::vector<int64_t> *rows) {
288+
std::string line;
289+
std::vector<std::string> columns;
290+
std::ifstream file(input_file);
291+
292+
rows->reserve(feasigns);
293+
294+
while (std::getline(file, line)) {
295+
columns = string::Split(line, '\t');
296+
if (columns.size() != 5) {
297+
VLOG(0) << "unexpected line: " << line << ", skip it";
298+
continue;
299+
}
300+
rows->push_back(std::stoull(columns[0]));
301+
}
302+
303+
VLOG(0) << "parse " << rows->size() << " embedding rows from "
304+
<< input_file;
305+
}
306+
307+
private:
308+
std::unique_ptr<::ThreadPool> pool_;
309+
};
310+
} // namespace distributed
311+
} // namespace paddle

paddle/fluid/distributed/common/utils.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#pragma once
1616

17+
#include <sys/time.h>
18+
1719
#include <functional>
1820
#include <memory>
1921
#include <string>
@@ -83,5 +85,11 @@ std::string to_string(const std::vector<T>& vec) {
8385
}
8486
return ss.str();
8587
}
88+
89+
// Returns the current time reported by gettimeofday(), expressed in
// microseconds: the seconds field scaled by 1e6 plus the microseconds field.
inline double GetCurrentUS() {
  struct timeval now;
  gettimeofday(&now, nullptr);
  const double seconds_in_us = 1e+6 * now.tv_sec;
  return seconds_in_us + now.tv_usec;
}
87-
}
94+
} // namespace distributed
95+
} // namespace paddle

0 commit comments

Comments
 (0)