|
| 1 | +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
#pragma once

#include <sys/time.h>

#include <algorithm>
#include <chrono>
#include <fstream>
#include <future>
#include <iostream>
#include <memory>
#include <ostream>
#include <string>
#include <thread>  // NOLINT
#include <vector>

#include <ThreadPool.h>
#include "boost/lexical_cast.hpp"
#include "glog/logging.h"
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/dim.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/string/split.h"
| 33 | + |
// Tuning/format constants for the sharded sparse-embedding merge below.
constexpr int FG = 256 * 1024 * 1024;  // 256 MB; not referenced in this
                                       // header — presumably a file/chunk
                                       // granularity limit, TODO confirm
constexpr int Q_SIZE = 10000;  // not referenced in this header — presumably a
                               // queue capacity, TODO confirm
constexpr int BUCKET = 10;  // number of batches each shard is split into when
                            // streaming rows/values (see ShardingMerge)
constexpr char XEOF[] = "EOF";  // not referenced in this header — presumably
                                // an end-of-stream sentinel, TODO confirm

// NOTE(review): a header-scope using-declaration leaks boost::lexical_cast
// into every translation unit that includes this header; prefer qualifying
// the call sites instead.
using boost::lexical_cast;
| 40 | + |
// Returns the current wall-clock time as microseconds since the Unix epoch,
// as a double (same contract as the old gettimeofday()-based version; callers
// only use differences, e.g. `end - begin`, to report elapsed time).
// Uses <chrono> instead of gettimeofday() for portability and type safety.
inline double GetCurrentUS() {
  const auto since_epoch = std::chrono::system_clock::now().time_since_epoch();
  const auto us =
      std::chrono::duration_cast<std::chrono::microseconds>(since_epoch);
  return static_cast<double>(us.count());
}
| 46 | + |
| 47 | +namespace paddle { |
| 48 | +namespace distributed { |
| 49 | + |
// Merges N text shards of a sparse embedding table into one binary file.
// Each input shard is a text file with one row per line, tab-separated into
// 5 columns: column 0 is the feasign (row id), column 4 holds the
// comma-separated embedding values.  The output layout is: rows section,
// tensor description, then the FP32 values.
// NOTE(review): the exact byte layout must match the SelectedRows/Tensor
// deserializer elsewhere in the codebase — verify before changing any field
// or the bucket interleaving order.
class ShardingMerge {
 public:
  ShardingMerge() {}
  ~ShardingMerge() {}

  // Merges every file in `inputs` into `output`.
  //   inputs:        paths of the shard files, one per shard.
  //   feasigns:      expected row count per shard (used only to reserve()).
  //   output:        destination binary file.
  //   embedding_dim: number of float values per row.
  void Merge(const std::vector<std::string> &inputs,
             const std::vector<int64_t> &feasigns, const std::string &output,
             const int embedding_dim) {
    // One worker thread per input shard.
    pool_.reset(new ::ThreadPool(inputs.size()));

    std::vector<std::future<int>> tasks(inputs.size());
    std::vector<std::vector<int64_t>> rows;
    rows.resize(inputs.size());

    auto begin = GetCurrentUS();
    // Parse the row ids of every shard concurrently; each task writes only
    // to its own rows[x] slot, so no locking is needed.
    // NOTE(review): signed/unsigned comparison (int x vs size_t).
    for (int x = 0; x < inputs.size(); ++x) {
      tasks[x] = pool_->enqueue([this, x, &rows, &inputs, &feasigns]() -> int {
        DeserializeRowsFromFile(inputs[x], feasigns[x], &rows[x]);
        return 0;
      });
    }

    for (size_t x = 0; x < tasks.size(); ++x) {
      tasks[x].wait();
    }

    int64_t total_rows = 0;
    for (auto x = 0; x < rows.size(); x++) {
      total_rows += rows[x].size();
    }

    auto end = GetCurrentUS();

    VLOG(0) << "got " << total_rows
            << " feasigin ids from sparse embedding using " << end - begin;

    // Dense shape of the merged tensor: [total_rows, embedding_dim].
    std::vector<int64_t> total_dims = {total_rows,
                                       static_cast<int64_t>(embedding_dim)};

    // bucket() (presumably from distributed/common/utils.h — confirm) splits
    // each shard's row count into BUCKET+1 boundary offsets so rows and
    // values can be streamed in BUCKET batches per shard.
    std::vector<std::vector<int>> batch_buckets;
    batch_buckets.resize(inputs.size());

    for (int x = 0; x < rows.size(); ++x) {
      batch_buckets[x] = bucket(rows[x].size(), BUCKET);
    }

    std::ofstream out(output, std::ios::binary);

    // Write the three sections in this exact order; the deserializer
    // depends on it.
    begin = GetCurrentUS();
    SerializeRowsToStream(out, rows, batch_buckets, total_rows);
    end = GetCurrentUS();
    VLOG(0) << "write rows to oostrream using " << end - begin;

    begin = GetCurrentUS();
    SerializePreTensorToStream(out, total_dims);
    end = GetCurrentUS();
    VLOG(0) << "write pretensor to oostrream using " << end - begin;

    begin = GetCurrentUS();
    SerializeValueToStream(out, inputs, batch_buckets, embedding_dim);
    end = GetCurrentUS();
    VLOG(0) << "write values to oostrream using " << end - begin;
  }

 private:
  // Writes the rows section: uint32 version, int64 total row count, then the
  // row ids of every shard interleaved bucket-by-bucket (bucket 0 of each
  // shard, bucket 1 of each shard, ...), and finally the int64 height.
  // The interleaving must match SerializeValueToStream below.
  void SerializeRowsToStream(std::ostream &os,
                             const std::vector<std::vector<int64_t>> &rows,
                             const std::vector<std::vector<int>> &batch_buckets,
                             int64_t total_rows) {
    {  // the 1st field, uint32_t version
      constexpr uint32_t version = 0;
      os.write(reinterpret_cast<const char *>(&version), sizeof(version));
    }

    {
      // the 2nd field, rows information
      os.write(reinterpret_cast<const char *>(&total_rows), sizeof(total_rows));

      for (int b = 0; b < BUCKET; ++b) {
        for (int x = 0; x < batch_buckets.size(); ++x) {
          // [begin, end) offsets of bucket b within shard x's row vector.
          auto begin = batch_buckets[x][b];
          auto end = batch_buckets[x][b + 1];

          if (end - begin == 0) continue;

          os.write(reinterpret_cast<const char *>(rows[x].data() + begin),
                   sizeof(int64_t) * (end - begin));
        }
      }

      // the 3rd field, the height of SelectedRows
      int64_t height = total_rows;
      os.write(reinterpret_cast<const char *>(&height), sizeof(height));
    }
  }

  // Writes the tensor-description section: uint32 version, the int32 byte
  // size of a protobuf TensorDesc (FP32 dtype + dims), then the serialized
  // TensorDesc itself.
  void SerializePreTensorToStream(std::ostream &os,
                                  const std::vector<int64_t> &dims) {
    {  // the 1st field, uint32_t version
      constexpr uint32_t version = 0;
      os.write(reinterpret_cast<const char *>(&version), sizeof(version));
    }
    {  // the 2nd field, tensor description
      // int32_t size
      framework::proto::VarType::TensorDesc desc;
      desc.set_data_type(framework::proto::VarType::FP32);
      auto *pb_dims = desc.mutable_dims();
      pb_dims->Resize(static_cast<int>(dims.size()), 0);
      std::copy(dims.begin(), dims.end(), pb_dims->begin());
      // NOTE(review): protobuf ByteSize() is deprecated in favor of
      // ByteSizeLong(); int32_t is fine while the message stays < 2 GB.
      int32_t size = desc.ByteSize();
      os.write(reinterpret_cast<const char *>(&size), sizeof(size));
      auto out = desc.SerializeAsString();
      os.write(out.data(), size);
    }
  }

  // Reads up to `batch` lines from `in`, parses column 4 as comma-separated
  // floats, and appends embedding_dim floats per line to *out.  A producer
  // thread pushes parsed lines into a blocking queue; a consumer thread
  // converts them to float.  An empty vector on the queue signals
  // end-of-batch.
  void SerializeValueToVec(std::ifstream &in, const int batch,
                           const int embedding_dim, std::vector<float> *out) {
    auto queue =
        std::make_shared<framework::BlockingQueue<std::vector<std::string>>>();

    // Producer: read and split at most `batch` well-formed lines.
    auto read = [batch, &in, &queue]() {
      std::string line;
      std::vector<std::string> columns;
      std::vector<std::string> values_str;

      int count = 0;

      while (std::getline(in, line)) {
        ++count;
        columns = string::Split(line, '\t');

        if (columns.size() != 5) {
          VLOG(0) << "unexpected line: " << line << ", skip it";
          continue;
        }

        values_str = string::Split(columns[4], ',');
        queue->Push(values_str);

        if (count >= batch) {
          break;
        }
      }
      // Sentinel: an empty vector tells the writer thread to stop.
      queue->Push({});
    };

    // Consumer: convert each queued line's values to float and append.
    auto write = [embedding_dim, &out, &queue]() {
      std::vector<std::string> values_str;
      std::string line;

      while (true) {
        queue->Pop(&values_str);

        if (values_str.size() == 0) {
          break;
        }

        // NOTE(review): indexes values_str[0..embedding_dim) without checking
        // values_str.size(); a malformed line with fewer values would read
        // out of range — confirm the inputs guarantee embedding_dim values.
        for (int x = 0; x < embedding_dim; ++x) {
          float v = 0.0;
          try {
            v = lexical_cast<float>(values_str[x]);
          } catch (boost::bad_lexical_cast &e) {
            // NOTE(review): `line` is never assigned in this lambda, so this
            // always logs an empty string.
            VLOG(0) << " get unexpected line: " << line;
          }
          out->push_back(v);
        }
      }
    };

    std::thread p_read(read);
    std::thread p_write(write);
    p_read.join();
    p_write.join();
  }

  // Raw-writes one bucket's float payload to the output stream.
  void SerializeVecToStream(std::ostream &out,
                            const std::vector<float> &value) {
    out.write(reinterpret_cast<const char *>(value.data()),
              static_cast<std::streamsize>(sizeof(float) * value.size()));
  }

  // Writes the values section: for each bucket b, reads that bucket's lines
  // from every shard concurrently, then appends each shard's floats to the
  // stream in shard order — mirroring the row interleaving written by
  // SerializeRowsToStream.  Each shard's ifstream advances across bucket
  // iterations, so every line is consumed exactly once.
  void SerializeValueToStream(
      std::ostream &out, const std::vector<std::string> &ins,
      const std::vector<std::vector<int>> &batch_buckets,
      const int embedding_dim) {
    std::vector<std::shared_ptr<std::ifstream>> in_streams;

    for (int x = 0; x < ins.size(); ++x) {
      in_streams.emplace_back(std::make_shared<std::ifstream>(ins[x]));
    }

    std::vector<std::future<int>> tasks(ins.size());

    for (int b = 0; b < BUCKET; ++b) {
      std::vector<std::vector<float>> values;
      values.resize(tasks.size());

      auto begin = GetCurrentUS();

      // Pre-size each shard's buffer for this bucket.
      for (int x = 0; x < tasks.size(); ++x) {
        auto batch = batch_buckets[x][b + 1] - batch_buckets[x][b];
        values[x].clear();
        values[x].reserve(batch * embedding_dim);
      }

      // Parse every shard's bucket concurrently; each task writes only to
      // its own values[x] slot.
      for (int x = 0; x < tasks.size(); ++x) {
        tasks[x] =
            pool_->enqueue([this, b, x, &out, &in_streams, &batch_buckets,
                            &values, embedding_dim]() -> int {
              auto batch = batch_buckets[x][b + 1] - batch_buckets[x][b];
              if (batch == 0) return 0;
              SerializeValueToVec(*(in_streams[x].get()), batch, embedding_dim,
                                  &values[x]);
              return 0;
            });
      }

      for (size_t x = 0; x < tasks.size(); ++x) {
        tasks[x].wait();
      }

      auto end = GetCurrentUS();

      // Sequentially flush this bucket, shard by shard, in shard order.
      auto begin1 = GetCurrentUS();
      for (size_t x = 0; x < tasks.size(); ++x) {
        SerializeVecToStream(out, values[x]);
      }
      auto end1 = GetCurrentUS();

      VLOG(0) << "serialize buckets " << b << " read using " << end - begin
              << ", to oostream using " << end1 - begin1;
    }
  }

  // Parses the row ids (column 0 of each tab-separated line) of one shard
  // file into *rows.  `feasigns` is the expected row count, used only for
  // reserve(); lines without exactly 5 columns are skipped with a log.
  // NOTE(review): std::stoull parses unsigned and the result is stored into
  // int64_t; ids above INT64_MAX would wrap — confirm the id range.
  void DeserializeRowsFromFile(const std::string &input_file,
                               const int64_t feasigns,
                               std::vector<int64_t> *rows) {
    std::string line;
    std::vector<std::string> columns;
    std::ifstream file(input_file);

    rows->reserve(feasigns);

    while (std::getline(file, line)) {
      columns = string::Split(line, '\t');
      if (columns.size() != 5) {
        VLOG(0) << "unexpected line: " << line << ", skip it";
        continue;
      }
      rows->push_back(std::stoull(columns[0]));
    }

    VLOG(0) << "parse " << rows->size() << " embedding rows from "
            << input_file;
  }

 private:
  // Worker pool shared by Merge() and SerializeValueToStream(); sized to the
  // number of input shards in Merge().
  std::unique_ptr<::ThreadPool> pool_;
};
| 310 | +} // namespace distributed |
| 311 | +} // namespace paddle |
0 commit comments