From 2cc3160b80bec1f4c8a9f649d39cd324cee062f7 Mon Sep 17 00:00:00 2001 From: Thunderbrook Date: Thu, 20 May 2021 21:01:14 +0800 Subject: [PATCH 1/8] support ssd in PsCore --- cmake/external/rocksdb.cmake | 45 +++ cmake/third_party.cmake | 3 + paddle/fluid/distributed/fleet.cc | 11 +- paddle/fluid/distributed/fleet.h | 2 +- .../distributed/service/ps_local_client.cc | 11 +- .../distributed/service/ps_local_server.h | 7 +- paddle/fluid/distributed/service/server.h | 2 +- paddle/fluid/distributed/table/CMakeLists.txt | 7 +- .../distributed/table/common_sparse_table.cc | 95 +---- .../distributed/table/common_sparse_table.h | 91 ++++- .../table/depends/large_scale_kv.h | 32 +- .../table/depends/rocksdb_warpper.h | 174 +++++++++ .../distributed/table/ssd_sparse_table.cc | 359 ++++++++++++++++++ .../distributed/table/ssd_sparse_table.h | 56 +++ paddle/fluid/distributed/table/table.cc | 2 + paddle/fluid/operators/lookup_table_op.cc | 5 + paddle/fluid/pybind/fleet_py.cc | 2 + python/paddle/distributed/fleet/__init__.py | 1 + .../distributed/fleet/base/fleet_base.py | 23 ++ .../distributed/fleet/runtime/the_one_ps.py | 25 +- python/paddle/fluid/contrib/layers/nn.py | 11 +- .../fleet/parameter_server/ir/trainer_pass.py | 34 ++ 22 files changed, 894 insertions(+), 104 deletions(-) create mode 100644 cmake/external/rocksdb.cmake create mode 100644 paddle/fluid/distributed/table/depends/rocksdb_warpper.h create mode 100644 paddle/fluid/distributed/table/ssd_sparse_table.cc create mode 100644 paddle/fluid/distributed/table/ssd_sparse_table.h diff --git a/cmake/external/rocksdb.cmake b/cmake/external/rocksdb.cmake new file mode 100644 index 00000000000000..b73d3626a43c37 --- /dev/null +++ b/cmake/external/rocksdb.cmake @@ -0,0 +1,45 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +INCLUDE(ExternalProject) + +SET(ROCKSDB_SOURCES_DIR ${THIRD_PARTY_PATH}/rocksdb) +SET(ROCKSDB_INSTALL_DIR ${THIRD_PARTY_PATH}/install/rocksdb) +SET(ROCKSDB_INCLUDE_DIR "${ROCKSDB_INSTALL_DIR}/include" CACHE PATH "rocksdb include directory." FORCE) +SET(ROCKSDB_LIBRARIES "${ROCKSDB_INSTALL_DIR}/lib/librocksdb.a" CACHE FILEPATH "rocksdb library." 
FORCE) +INCLUDE_DIRECTORIES(${ROCKSDB_INCLUDE_DIR}) + +ExternalProject_Add( + extern_rocksdb + ${EXTERNAL_PROJECT_LOG_ARGS} + PREFIX ${ROCKSDB_SOURCES_DIR} + GIT_REPOSITORY "https://github.com/facebook/rocksdb" + GIT_TAG v5.1.4 + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND CXXFLAGS=-fPIC make static_lib + INSTALL_COMMAND mkdir -p ${ROCKSDB_INSTALL_DIR}/lib/ + && cp ${ROCKSDB_SOURCES_DIR}/src/extern_rocksdb/librocksdb.a ${ROCKSDB_LIBRARIES} + && cp -r ${ROCKSDB_SOURCES_DIR}/src/extern_rocksdb/include ${ROCKSDB_INSTALL_DIR}/ + BUILD_IN_SOURCE 1 +) + +ADD_DEPENDENCIES(extern_rocksdb snappy) + +ADD_LIBRARY(rocksdb STATIC IMPORTED GLOBAL) +SET_PROPERTY(TARGET rocksdb PROPERTY IMPORTED_LOCATION ${ROCKSDB_LIBRARIES}) +ADD_DEPENDENCIES(rocksdb extern_rocksdb) + +LIST(APPEND external_project_dependencies rocksdb) + diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index 56edaff2a50dab..0dd7a86df26573 100644 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -304,6 +304,9 @@ if (WITH_PSCORE) include(external/libmct) # download, build, install libmct list(APPEND third_party_deps extern_libmct) + + include(external/rocksdb) # download, build, install libmct + list(APPEND third_party_deps extern_rocksdb) endif() if(WITH_XBYAK) diff --git a/paddle/fluid/distributed/fleet.cc b/paddle/fluid/distributed/fleet.cc index dfd55f16e1a065..9e2a0b35224a4e 100644 --- a/paddle/fluid/distributed/fleet.cc +++ b/paddle/fluid/distributed/fleet.cc @@ -417,8 +417,10 @@ void FleetWrapper::PushSparseFromTensorWithLabelAsync( return; } -void FleetWrapper::LoadModel(const std::string& path, const int mode) { - auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode)); +void FleetWrapper::LoadModel(const std::string& path, const std::string& mode) { + auto* communicator = Communicator::GetInstance(); + auto ret = communicator->_worker_ptr->load(path, mode); + // auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode)); ret.wait(); if (ret.get() != 0) { LOG(ERROR) << "load model from path:" << path << " failed"; @@ -429,8 +431,11 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) { void FleetWrapper::LoadModelOneTable(const uint64_t table_id, const std::string& path, const int mode) { + auto* communicator = Communicator::GetInstance(); auto ret = - pserver_ptr_->_worker_ptr->load(table_id, path, std::to_string(mode)); + communicator->_worker_ptr->load(table_id, path, std::to_string(mode)); + // auto ret = + // pserver_ptr_->_worker_ptr->load(table_id, path, std::to_string(mode)); ret.wait(); if (ret.get() != 0) { LOG(ERROR) << "load model of table id: " << table_id diff --git a/paddle/fluid/distributed/fleet.h b/paddle/fluid/distributed/fleet.h index 0da5d1e2bf987f..1b2bde85de04c2 100644 --- a/paddle/fluid/distributed/fleet.h +++ b/paddle/fluid/distributed/fleet.h @@ -200,7 +200,7 @@ class FleetWrapper { void PrintTableStat(const uint64_t table_id); // mode = 0, load all feature // mode = 1, load delta feature, which means load diff - void LoadModel(const std::string& path, const int mode); + void LoadModel(const std::string& path, const std::string& mode); // mode = 0, load all feature // mode = 1, load delta feature, which means load diff void LoadModelOneTable(const uint64_t table_id, const std::string& path, diff --git a/paddle/fluid/distributed/service/ps_local_client.cc b/paddle/fluid/distributed/service/ps_local_client.cc index 2acc845a50890b..e949b21b02e6d9 100644 --- 
a/paddle/fluid/distributed/service/ps_local_client.cc +++ b/paddle/fluid/distributed/service/ps_local_client.cc @@ -42,17 +42,17 @@ ::std::future PsLocalClient::shrink(uint32_t table_id, ::std::future PsLocalClient::load(const std::string& epoch, const std::string& mode) { // TODO - // for (auto& it : _table_map) { - // load(it.first, epoch, mode); - //} + for (auto& it : _table_map) { + load(it.first, epoch, mode); + } return done(); } ::std::future PsLocalClient::load(uint32_t table_id, const std::string& epoch, const std::string& mode) { // TODO - // auto* table_ptr = table(table_id); - // table_ptr->load(epoch, mode); + auto* table_ptr = table(table_id); + table_ptr->load(epoch, mode); return done(); } @@ -245,7 +245,6 @@ ::std::future PsLocalClient::pull_sparse_ptr(char** select_values, ::std::future PsLocalClient::push_sparse_raw_gradient( size_t table_id, const uint64_t* keys, const float** update_values, size_t num, void* callback) { - VLOG(1) << "wxx push_sparse_raw_gradient"; PSClientClosure* closure = reinterpret_cast(callback); auto* accessor = table_accessor(table_id); auto* table_ptr = table(table_id); diff --git a/paddle/fluid/distributed/service/ps_local_server.h b/paddle/fluid/distributed/service/ps_local_server.h index dfbccc70900e3c..33b0b5fa796d75 100644 --- a/paddle/fluid/distributed/service/ps_local_server.h +++ b/paddle/fluid/distributed/service/ps_local_server.h @@ -26,9 +26,14 @@ class PsLocalServer : public PSServer { PsLocalServer() {} virtual ~PsLocalServer() {} virtual uint64_t start() { return 0; } - virtual uint64_t start(const std::string& ip, uint32_t port) { return 0; } + virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; } virtual int32_t stop() { return 0; } virtual int32_t port() { return 0; } + virtual int32_t configure( + const PSParameter &config, PSEnvironment &env, size_t server_rank, + const std::vector &server_sub_program = {}) { + return 0; + } private: virtual int32_t initialize() { return 0; } diff --git a/paddle/fluid/distributed/service/server.h b/paddle/fluid/distributed/service/server.h index 74a8cbe44b144b..89b089386f5018 100644 --- a/paddle/fluid/distributed/service/server.h +++ b/paddle/fluid/distributed/service/server.h @@ -70,7 +70,7 @@ class PSServer { virtual int32_t configure( const PSParameter &config, PSEnvironment &env, size_t server_rank, - const std::vector &server_sub_program = {}) final; + const std::vector &server_sub_program = {}); // return server_ip virtual std::string ip() { return butil::my_ip_cstr(); } diff --git a/paddle/fluid/distributed/table/CMakeLists.txt b/paddle/fluid/distributed/table/CMakeLists.txt index dab390958034af..443fa928d1268f 100644 --- a/paddle/fluid/distributed/table/CMakeLists.txt +++ b/paddle/fluid/distributed/table/CMakeLists.txt @@ -9,15 +9,17 @@ set_source_files_properties(${graphDir}/graph_node.cc PROPERTIES COMPILE_FLAGS $ cc_library(graph_node SRCS ${graphDir}/graph_node.cc DEPS WeightedSampler) set_source_files_properties(common_dense_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(common_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(ssd_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(barrier_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(common_graph_table.cc PROPERTIES 
COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) -cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc +cc_library(common_table SRCS common_sparse_table.cc ssd_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc DEPS ${TABLE_DEPS} -${RPC_DEPS} graph_edge graph_node device_context string_helper simple_threadpool xxhash generator) +${RPC_DEPS} graph_edge graph_node device_context string_helper +simple_threadpool xxhash generator rocksdb) set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(tensor_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) @@ -27,3 +29,4 @@ cc_library(tensor_table SRCS tensor_table.cc DEPS eigen3 ps_framework_proto exec set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_library(table SRCS table.cc DEPS common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost) +target_link_libraries(table -lbz2) diff --git a/paddle/fluid/distributed/table/common_sparse_table.cc b/paddle/fluid/distributed/table/common_sparse_table.cc index a4f672c2963a84..6867e49f73df4b 100644 --- a/paddle/fluid/distributed/table/common_sparse_table.cc +++ b/paddle/fluid/distributed/table/common_sparse_table.cc @@ -25,83 +25,12 @@ class ValueBlock; } // namespace distributed } // namespace paddle -#define PSERVER_SAVE_SUFFIX ".shard" -using boost::lexical_cast; - namespace paddle { namespace distributed { -enum SaveMode { all, base, delta }; - -struct Meta { - std::string param; - int shard_id; - std::vector names; - std::vector dims; - uint64_t count; - std::unordered_map dims_map; - - explicit Meta(const std::string& metapath) { - std::ifstream file(metapath); - std::string line; - int num_lines = 0; - while (std::getline(file, line)) { - if (StartWith(line, "#")) { - continue; - } - auto pairs = paddle::string::split_string(line, "="); - PADDLE_ENFORCE_EQ( - pairs.size(), 2, - paddle::platform::errors::InvalidArgument( - "info in %s except k=v, but got %s", metapath, line)); - - if (pairs[0] == "param") { - param = pairs[1]; - } - if (pairs[0] == "shard_id") { - shard_id = std::stoi(pairs[1]); - } - if (pairs[0] == "row_names") { - names = paddle::string::split_string(pairs[1], ","); - } - if (pairs[0] == "row_dims") { - auto dims_strs = - paddle::string::split_string(pairs[1], ","); - for (auto& str : dims_strs) { - dims.push_back(std::stoi(str)); - } - } - if (pairs[0] == "count") { - count = std::stoull(pairs[1]); - } - } - for (int x = 0; x < names.size(); ++x) { - dims_map[names[x]] = dims[x]; - } - } - - Meta(std::string param, int shard_id, std::vector row_names, - std::vector dims, uint64_t count) { - this->param = param; - this->shard_id = shard_id; - this->names = row_names; - this->dims = dims; - this->count = count; - } - - std::string ToString() { - std::stringstream ss; - ss << "param=" << param << "\n"; - ss << "shard_id=" << shard_id << "\n"; - ss << "row_names=" << paddle::string::join_strings(names, ',') << "\n"; - ss << "row_dims=" << paddle::string::join_strings(dims, ',') << "\n"; - ss << "count=" << count << "\n"; - return ss.str(); - } -}; - -void ProcessALine(const std::vector& columns, const Meta& meta, - const int64_t id, std::vector>* values) { +void CommonSparseTable::ProcessALine(const std::vector& columns, + const Meta& meta, const int64_t id, + std::vector>* values) { auto colunmn_size 
= columns.size(); auto load_values = paddle::string::split_string(columns[colunmn_size - 1], ","); @@ -134,8 +63,9 @@ void ProcessALine(const std::vector& columns, const Meta& meta, } } -int64_t SaveToText(std::ostream* os, std::shared_ptr block, - const int mode) { +int64_t CommonSparseTable::SaveToText(std::ostream* os, + std::shared_ptr block, + const int mode, int shard_id) { int64_t save_num = 0; for (auto& table : block->values_) { @@ -173,10 +103,10 @@ int64_t SaveToText(std::ostream* os, std::shared_ptr block, return save_num; } -int64_t LoadFromText(const std::string& valuepath, const std::string& metapath, - const int pserver_id, const int pserver_num, - const int local_shard_num, - std::vector>* blocks) { +int64_t CommonSparseTable::LoadFromText( + const std::string& valuepath, const std::string& metapath, + const int pserver_id, const int pserver_num, const int local_shard_num, + std::vector>* blocks) { Meta meta = Meta(metapath); int num_lines = 0; @@ -185,7 +115,7 @@ int64_t LoadFromText(const std::string& valuepath, const std::string& metapath, while (std::getline(file, line)) { auto values = paddle::string::split_string(line, "\t"); - auto id = lexical_cast(values[0]); + auto id = lexical_cast(values[0]); if (id % pserver_num != pserver_id) { VLOG(3) << "will not load " << values[0] << " from " << valuepath @@ -366,7 +296,8 @@ int32_t CommonSparseTable::save(const std::string& dirname, int64_t total_ins = 0; for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) { // save values - total_ins += SaveToText(value_out.get(), shard_values_[shard_id], mode); + total_ins += + SaveToText(value_out.get(), shard_values_[shard_id], mode, shard_id); } value_out->close(); diff --git a/paddle/fluid/distributed/table/common_sparse_table.h b/paddle/fluid/distributed/table/common_sparse_table.h index 50c295da53464c..439e261af64923 100644 --- a/paddle/fluid/distributed/table/common_sparse_table.h +++ b/paddle/fluid/distributed/table/common_sparse_table.h @@ -32,11 +32,83 @@ #include "paddle/fluid/framework/rw_lock.h" #include "paddle/fluid/string/string_helper.h" +#define PSERVER_SAVE_SUFFIX ".shard" +using boost::lexical_cast; + namespace paddle { namespace distributed { class SparseOptimizer; +enum SaveMode { all, base, delta }; + +struct Meta { + std::string param; + int shard_id; + std::vector names; + std::vector dims; + uint64_t count; + std::unordered_map dims_map; + + explicit Meta(const std::string& metapath) { + std::ifstream file(metapath); + std::string line; + int num_lines = 0; + while (std::getline(file, line)) { + if (StartWith(line, "#")) { + continue; + } + auto pairs = paddle::string::split_string(line, "="); + PADDLE_ENFORCE_EQ( + pairs.size(), 2, + paddle::platform::errors::InvalidArgument( + "info in %s except k=v, but got %s", metapath, line)); + + if (pairs[0] == "param") { + param = pairs[1]; + } + if (pairs[0] == "shard_id") { + shard_id = std::stoi(pairs[1]); + } + if (pairs[0] == "row_names") { + names = paddle::string::split_string(pairs[1], ","); + } + if (pairs[0] == "row_dims") { + auto dims_strs = + paddle::string::split_string(pairs[1], ","); + for (auto& str : dims_strs) { + dims.push_back(std::stoi(str)); + } + } + if (pairs[0] == "count") { + count = std::stoull(pairs[1]); + } + } + for (int x = 0; x < names.size(); ++x) { + dims_map[names[x]] = dims[x]; + } + } + + Meta(std::string param, int shard_id, std::vector row_names, + std::vector dims, uint64_t count) { + this->param = param; + this->shard_id = shard_id; + this->names = row_names; + 
this->dims = dims; + this->count = count; + } + + std::string ToString() { + std::stringstream ss; + ss << "param=" << param << "\n"; + ss << "shard_id=" << shard_id << "\n"; + ss << "row_names=" << paddle::string::join_strings(names, ',') << "\n"; + ss << "row_dims=" << paddle::string::join_strings(dims, ',') << "\n"; + ss << "count=" << count << "\n"; + return ss.str(); + } +}; + class CommonSparseTable : public SparseTable { public: CommonSparseTable() { rwlock_.reset(new framework::RWLock); } @@ -56,9 +128,22 @@ class CommonSparseTable : public SparseTable { virtual int32_t initialize_optimizer(); virtual int32_t initialize_recorder(); - int32_t load(const std::string& path, const std::string& param); + virtual int32_t load(const std::string& path, const std::string& param); + + virtual int32_t save(const std::string& path, const std::string& param); - int32_t save(const std::string& path, const std::string& param); + virtual int64_t SaveToText(std::ostream* os, + std::shared_ptr block, const int mode, + int shard_id); + + virtual void ProcessALine(const std::vector& columns, + const Meta& meta, const int64_t id, + std::vector>* values); + + virtual int64_t LoadFromText( + const std::string& valuepath, const std::string& metapath, + const int pserver_id, const int pserver_num, const int local_shard_num, + std::vector>* blocks); virtual std::pair print_table_stat(); virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); @@ -89,7 +174,7 @@ class CommonSparseTable : public SparseTable { virtual int32_t _push_sparse(const uint64_t* keys, const float** values, size_t num); - private: + protected: const int task_pool_size_ = 11; std::vector> _shards_task_pool; diff --git a/paddle/fluid/distributed/table/depends/large_scale_kv.h b/paddle/fluid/distributed/table/depends/large_scale_kv.h index 5c10fca98cda4d..ac11183d192fff 100644 --- a/paddle/fluid/distributed/table/depends/large_scale_kv.h +++ b/paddle/fluid/distributed/table/depends/large_scale_kv.h @@ -83,6 +83,7 @@ inline bool probility_entry(VALUE *value, float threshold) { class ValueBlock { public: + typedef typename robin_hood::unordered_map map_type; explicit ValueBlock(const std::vector &value_names, const std::vector &value_dims, const std::vector &value_offsets, @@ -261,6 +262,18 @@ class ValueBlock { value->is_entry_ = state; } + void erase(uint64_t feasign) { + size_t hash = _hasher(feasign); + size_t bucket = compute_bucket(hash); + auto &table = values_[bucket]; + + auto iter = table.find(feasign); + if (iter != table.end()) { + butil::return_object(iter->second); + iter = table.erase(iter); + } + } + void Shrink(const int threshold) { for (auto &table : values_) { for (auto iter = table.begin(); iter != table.end();) { @@ -289,6 +302,23 @@ class ValueBlock { } } + map_type::iterator end() { + return values_[SPARSE_SHARD_BUCKET_NUM - 1].end(); + } + + map_type::iterator Find(uint64_t id) { + size_t hash = _hasher(id); + size_t bucket = compute_bucket(hash); + auto &table = values_[bucket]; + + auto got = table.find(id); + if (got == table.end()) { + return end(); + } else { + return got; + } + } + private: bool Has(const uint64_t id) { size_t hash = _hasher(id); @@ -304,7 +334,7 @@ class ValueBlock { } public: - robin_hood::unordered_map values_[SPARSE_SHARD_BUCKET_NUM]; + map_type values_[SPARSE_SHARD_BUCKET_NUM]; size_t value_length_ = 0; std::hash _hasher; diff --git a/paddle/fluid/distributed/table/depends/rocksdb_warpper.h b/paddle/fluid/distributed/table/depends/rocksdb_warpper.h new file mode 
100644 index 00000000000000..d010ccfe1b8e83 --- /dev/null +++ b/paddle/fluid/distributed/table/depends/rocksdb_warpper.h @@ -0,0 +1,174 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace paddle { +namespace distributed { + +class RocksDBHandler { + public: + RocksDBHandler() {} + ~RocksDBHandler() {} + + static RocksDBHandler* GetInstance() { + static RocksDBHandler handler; + return &handler; + } + + int initialize(const std::string& db_path, const int colnum) { + VLOG(3) << "db path: " << db_path << " colnum: " << colnum; + rocksdb::Options options; + rocksdb::BlockBasedTableOptions bbto; + bbto.block_size = 4 * 1024; + bbto.block_cache = rocksdb::NewLRUCache(64 * 1024 * 1024); + bbto.block_cache_compressed = rocksdb::NewLRUCache(64 * 1024 * 1024); + bbto.cache_index_and_filter_blocks = false; + bbto.filter_policy.reset(rocksdb::NewBloomFilterPolicy(20, false)); + bbto.whole_key_filtering = true; + options.table_factory.reset(rocksdb::NewBlockBasedTableFactory(bbto)); + VLOG(3) << "wxx aaa "; + + // options.IncreaseParallelism(); + // options.OptimizeLevelStyleCompaction(); + options.keep_log_file_num = 100; + // options.db_log_dir = "./log/rocksdb"; + options.max_log_file_size = 50 * 1024 * 1024; // 50MB + // options.threads = 8; + options.create_if_missing = true; + options.use_direct_reads = true; + options.use_direct_writes = true; + options.max_background_flushes = 5; + options.max_background_compactions = 5; + options.base_background_compactions = 10; + options.write_buffer_size = 256 * 1024 * 1024; // 256MB + options.max_write_buffer_number = 8; + options.max_bytes_for_level_base = + options.max_write_buffer_number * options.write_buffer_size; + options.min_write_buffer_number_to_merge = 1; + options.target_file_size_base = 1024 * 1024 * 1024; // 1024MB + // options.verify_checksums_in_compaction = false; + // options.disable_auto_compactions = true; + options.memtable_prefix_bloom_size_ratio = 0.02; + options.num_levels = 4; + options.max_open_files = -1; + VLOG(3) << "wxx bbb "; + + options.compression = rocksdb::kNoCompression; + // options.compaction_options_fifo = rocksdb::CompactionOptionsFIFO(); + // options.compaction_style = + // rocksdb::CompactionStyle::kCompactionStyleFIFO; + options.level0_file_num_compaction_trigger = 8; + options.level0_slowdown_writes_trigger = + 1.8 * options.level0_file_num_compaction_trigger; + options.level0_stop_writes_trigger = + 3.6 * options.level0_file_num_compaction_trigger; + + VLOG(3) << "wxx ccc "; + if (!db_path.empty()) { + std::string rm_cmd = "rm -rf " + db_path; + system(rm_cmd.c_str()); + } + VLOG(3) << "wxx ddd "; + + rocksdb::Status s = rocksdb::DB::Open(options, db_path, &_db); + VLOG(3) << "wxx eee "; + assert(s.ok()); + _handles.resize(colnum); + VLOG(3) << "wxx ahaha "; + for (int i = 0; i < colnum; i++) { + s = _db->CreateColumnFamily(options, 
"shard_" + std::to_string(i), + &_handles[i]); + VLOG(3) << "wxx hihihi "; + assert(s.ok()); + } + VLOG(3) << "wxx fff "; + LOG(INFO) << "DB initialize success, colnum:" << colnum; + return 0; + } + + int put(int id, const char* key, int key_len, const char* value, + int value_len) { + rocksdb::WriteOptions options; + options.disableWAL = true; + rocksdb::Status s = + _db->Put(options, _handles[id], rocksdb::Slice(key, key_len), + rocksdb::Slice(value, value_len)); + assert(s.ok()); + return 0; + } + + int put_batch(int id, std::vector>& ssd_keys, + std::vector>& ssd_values, int n) { + rocksdb::WriteOptions options; + options.disableWAL = true; + rocksdb::WriteBatch batch(n * 128); + for (int i = 0; i < n; i++) { + batch.Put(_handles[id], + rocksdb::Slice(ssd_keys[i].first, ssd_keys[i].second), + rocksdb::Slice(ssd_values[i].first, ssd_values[i].second)); + } + rocksdb::Status s = _db->Write(options, &batch); + assert(s.ok()); + return 0; + } + + int get(int id, const char* key, int key_len, std::string& value) { + rocksdb::Status s = _db->Get(rocksdb::ReadOptions(), _handles[id], + rocksdb::Slice(key, key_len), &value); + if (s.IsNotFound()) { + return 1; + } + assert(s.ok()); + return 0; + } + + int del_data(int id, const char* key, int key_len) { + rocksdb::WriteOptions options; + options.disableWAL = true; + rocksdb::Status s = + _db->Delete(options, _handles[id], rocksdb::Slice(key, key_len)); + assert(s.ok()); + return 0; + } + + int flush(int id) { + rocksdb::Status s = _db->Flush(rocksdb::FlushOptions(), _handles[id]); + assert(s.ok()); + return 0; + } + + rocksdb::Iterator* get_iterator(int id) { + return _db->NewIterator(rocksdb::ReadOptions(), _handles[id]); + } + + int get_estimate_key_num(uint64_t& num_keys) { + _db->GetAggregatedIntProperty("rocksdb.estimate-num-keys", &num_keys); + return 0; + } + + private: + std::vector _handles; + rocksdb::DB* _db; +}; +} +} diff --git a/paddle/fluid/distributed/table/ssd_sparse_table.cc b/paddle/fluid/distributed/table/ssd_sparse_table.cc new file mode 100644 index 00000000000000..ab41d43c750c4b --- /dev/null +++ b/paddle/fluid/distributed/table/ssd_sparse_table.cc @@ -0,0 +1,359 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/distributed/table/ssd_sparse_table.h" + +DEFINE_string(rocksdb_path, "database", "path of sparse table rocksdb file"); + +namespace paddle { +namespace distributed { + +int32_t SSDSparseTable::initialize() { + _shards_task_pool.resize(task_pool_size_); + for (int i = 0; i < _shards_task_pool.size(); ++i) { + _shards_task_pool[i].reset(new ::ThreadPool(1)); + } + + sync = _config.common().sync(); + VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync; + + _global_lr = new float(1.0); + + auto common = _config.common(); + int size = static_cast(common.params().size()); + + size_t offset = 0; + for (int x = 0; x < size; ++x) { + auto& varname = common.params()[x]; + auto& dim = common.dims()[x]; + + value_idx_[varname] = x; + value_names_.push_back(varname); + value_dims_.push_back(dim); + value_offsets_.push_back(offset); + initializer_attrs_.push_back(common.initializers()[x]); + + if (varname == "Param") { + param_dim_ = dim; + param_offset_ = offset; + } + + offset += dim; + } + + initialize_value(); + initialize_optimizer(); + initialize_recorder(); + _db = paddle::distributed::RocksDBHandler::GetInstance(); + _db->initialize(FLAGS_rocksdb_path, task_pool_size_); + return 0; +} + +int32_t SSDSparseTable::pull_sparse(float* pull_values, + const PullSparseValue& pull_value) { + auto shard_num = task_pool_size_; + std::vector> tasks(shard_num); + + for (int shard_id = 0; shard_id < shard_num; ++shard_id) { + tasks[shard_id] = _shards_task_pool[shard_id]->enqueue( + [this, shard_id, shard_num, &pull_value, &pull_values]() -> int { + auto& block = shard_values_[shard_id]; + + std::vector offsets; + pull_value.Fission(shard_id, shard_num, &offsets); + + for (auto& offset : offsets) { + auto feasign = pull_value.feasigns_[offset]; + auto frequencie = pull_value.frequencies_[offset]; + float* embedding = nullptr; + auto iter = block->Find(feasign); + // in mem + if (iter == block->end()) { + embedding = iter->second->data_.data(); + if (pull_value.is_training_) { + block->AttrUpdate(iter->second, frequencie); + } + } else { + // need create + std::string tmp_str(""); + if (_db->get(shard_id, (char*)&feasign, sizeof(uint64_t), + tmp_str) > 0) { + embedding = block->Init(feasign, true, frequencie); + } else { + // in db + int data_size = tmp_str.size() / sizeof(float); + int value_size = block->value_length_; + float* db_value = (float*)const_cast(tmp_str.c_str()); + VALUE* value = block->InitGet(feasign); + + // copy to mem + memcpy(value->data_.data(), db_value, + value_size * sizeof(float)); + embedding = db_value; + + // param, count, unseen_day + value->count_ = db_value[value_size]; + value->unseen_days_ = db_value[value_size + 1]; + value->is_entry_ = db_value[value_size + 2]; + if (pull_value.is_training_) { + block->AttrUpdate(value, frequencie); + } + } + } + std::copy_n(embedding + param_offset_, param_dim_, + pull_values + param_dim_ * offset); + } + return 0; + }); + } + + for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { + tasks[shard_id].wait(); + } + return 0; +} + +int32_t SSDSparseTable::pull_sparse_ptr(char** pull_values, + const uint64_t* keys, size_t num) { + auto shard_num = task_pool_size_; + std::vector> tasks(shard_num); + + std::vector> offset_bucket; + offset_bucket.resize(task_pool_size_); + + for (int x = 0; x < num; ++x) { + auto y = keys[x] % task_pool_size_; + offset_bucket[y].push_back(x); + } + + for (int shard_id = 0; shard_id < shard_num; ++shard_id) { + tasks[shard_id] = 
_shards_task_pool[shard_id]->enqueue( + [this, shard_id, &keys, &pull_values, &offset_bucket]() -> int { + auto& block = shard_values_[shard_id]; + auto& offsets = offset_bucket[shard_id]; + + for (auto& offset : offsets) { + auto feasign = keys[offset]; + auto iter = block->Find(feasign); + VALUE* value = nullptr; + // in mem + if (iter != block->end()) { + value = iter->second; + } else { + // need create + std::string tmp_str(""); + if (_db->get(shard_id, (char*)&feasign, sizeof(uint64_t), + tmp_str) > 0) { + value = block->InitGet(feasign); + } else { + // in db + int data_size = tmp_str.size() / sizeof(float); + int value_size = block->value_length_; + float* db_value = (float*)const_cast(tmp_str.c_str()); + value = block->InitGet(feasign); + + // copy to mem + memcpy(value->data_.data(), db_value, + value_size * sizeof(float)); + + // param, count, unseen_day + value->count_ = db_value[value_size]; + value->unseen_days_ = db_value[value_size + 1]; + value->is_entry_ = db_value[value_size + 2]; + } + } + pull_values[offset] = (char*)value; + } + return 0; + }); + } + + for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { + tasks[shard_id].wait(); + } + return 0; +} + +int32_t SSDSparseTable::shrink(const std::string& param) { return 0; } + +int32_t SSDSparseTable::update_table() { + int count = 0; + int value_size = shard_values_[0]->value_length_; + int db_size = 3 + value_size; + float tmp_value[db_size]; + + for (size_t i = 0; i < task_pool_size_; ++i) { + auto& block = shard_values_[i]; + + for (auto& table : block->values_) { + for (auto iter = table.begin(); iter != table.end();) { + VALUE* value = iter->second; + if (value->unseen_days_ >= 1) { + tmp_value[value_size] = value->count_; + tmp_value[value_size + 1] = value->unseen_days_; + tmp_value[value_size + 2] = value->is_entry_; + memcpy(tmp_value, value->data_.data(), sizeof(float) * value_size); + _db->put(i, (char*)&(iter->first), sizeof(uint64_t), (char*)tmp_value, + db_size * sizeof(float)); + count++; + + butil::return_object(iter->second); + iter = table.erase(iter); + } else { + ++iter; + } + } + } + _db->flush(i); + } + VLOG(1) << "Table>> update count: " << count; + return 0; +} + +int64_t SSDSparseTable::SaveToText(std::ostream* os, + std::shared_ptr block, + const int mode, int shard_id) { + int64_t save_num = 0; + + for (auto& table : block->values_) { + for (auto& value : table) { + if (mode == SaveMode::delta && !value.second->need_save_) { + continue; + } + + ++save_num; + + std::stringstream ss; + auto* vs = value.second->data_.data(); + + auto id = value.first; + + ss << id << "\t" << value.second->count_ << "\t" + << value.second->unseen_days_ << "\t" << value.second->is_entry_ + << "\t"; + + for (int i = 0; i < block->value_length_ - 1; i++) { + ss << std::to_string(vs[i]) << ","; + } + + ss << std::to_string(vs[block->value_length_ - 1]); + ss << "\n"; + + os->write(ss.str().c_str(), sizeof(char) * ss.str().size()); + + if (mode == SaveMode::base || mode == SaveMode::delta) { + value.second->need_save_ = false; + } + } + } + + if (mode != 1) { + int value_size = block->value_length_; + auto* it = _db->get_iterator(shard_id); + + for (it->SeekToFirst(); it->Valid(); it->Next()) { + float* value = (float*)const_cast(it->value().data()); + std::stringstream ss; + ss << *((uint64_t*)const_cast(it->key().data())) << "\t" + << value[value_size] << "\t" << value[value_size + 1] << "\t" + << value[value_size + 2] << "\t"; + for (int i = 0; i < block->value_length_ - 1; i++) { + ss << 
std::to_string(value[i]) << ","; + } + + ss << std::to_string(value[block->value_length_ - 1]); + ss << "\n"; + + os->write(ss.str().c_str(), sizeof(char) * ss.str().size()); + } + } + + return save_num; +} + +int32_t SSDSparseTable::load(const std::string& path, + const std::string& param) { + rwlock_->WRLock(); + VLOG(3) << "ssd sparse table load with " << path << " with meta " << param; + LoadFromText(path, param, _shard_idx, _shard_num, task_pool_size_, + &shard_values_); + rwlock_->UNLock(); + return 0; +} + +int64_t SSDSparseTable::LoadFromText( + const std::string& valuepath, const std::string& metapath, + const int pserver_id, const int pserver_num, const int local_shard_num, + std::vector>* blocks) { + Meta meta = Meta(metapath); + + int num_lines = 0; + std::ifstream file(valuepath); + std::string line; + + int value_size = shard_values_[0]->value_length_; + int db_size = 3 + value_size; + float tmp_value[db_size]; + + while (std::getline(file, line)) { + auto values = paddle::string::split_string(line, "\t"); + auto id = lexical_cast(values[0]); + + if (id % pserver_num != pserver_id) { + VLOG(3) << "will not load " << values[0] << " from " << valuepath + << ", please check id distribution"; + continue; + } + + auto shard_id = id % local_shard_num; + auto block = blocks->at(shard_id); + + std::vector> kvalues; + ProcessALine(values, meta, id, &kvalues); + + block->Init(id, false); + + VALUE* value_instant = block->GetValue(id); + + if (values.size() == 5) { + value_instant->count_ = lexical_cast(values[1]); + value_instant->unseen_days_ = lexical_cast(values[2]); + value_instant->is_entry_ = + static_cast(lexical_cast(values[3])); + } + + std::vector block_values = block->Get(id, meta.names, meta.dims); + auto blas = GetBlas(); + for (int x = 0; x < meta.names.size(); ++x) { + blas.VCOPY(meta.dims[x], kvalues[x].data(), block_values[x]); + } + VLOG(3) << "loading: " << id + << "unseen day: " << value_instant->unseen_days_; + if (value_instant->unseen_days_ >= 1) { + tmp_value[value_size] = value_instant->count_; + tmp_value[value_size + 1] = value_instant->unseen_days_; + tmp_value[value_size + 2] = value_instant->is_entry_; + memcpy(tmp_value, value_instant->data_.data(), + sizeof(float) * value_size); + _db->put(shard_id, (char*)&(id), sizeof(uint64_t), (char*)tmp_value, + db_size * sizeof(float)); + block->erase(id); + } + } + + return 0; +} + +} // namespace ps +} // namespace paddle diff --git a/paddle/fluid/distributed/table/ssd_sparse_table.h b/paddle/fluid/distributed/table/ssd_sparse_table.h new file mode 100644 index 00000000000000..796dd55e2e0804 --- /dev/null +++ b/paddle/fluid/distributed/table/ssd_sparse_table.h @@ -0,0 +1,56 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/fluid/distributed/table/common_sparse_table.h" +#include "paddle/fluid/distributed/table/depends/rocksdb_warpper.h" + +namespace paddle { +namespace distributed { +class SSDSparseTable : public CommonSparseTable { + public: + SSDSparseTable() {} + virtual ~SSDSparseTable() {} + + virtual int32_t initialize() override; + virtual int64_t SaveToText(std::ostream* os, + std::shared_ptr block, const int mode, + int shard_id); + + virtual int64_t LoadFromText( + const std::string& valuepath, const std::string& metapath, + const int pserver_id, const int pserver_num, const int local_shard_num, + std::vector>* blocks); + + virtual int32_t load(const std::string& path, const std::string& param); + + // exchange data + virtual int32_t update_table(); + + virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); + + virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys, + size_t num); + + virtual int32_t flush() override { return 0; } + virtual int32_t shrink(const std::string& param) override; + virtual void clear() override {} + + private: + RocksDBHandler* _db; + int64_t _cache_tk_size; +}; + +} // namespace ps +} // namespace paddle diff --git a/paddle/fluid/distributed/table/table.cc b/paddle/fluid/distributed/table/table.cc index 600be954cb5966..25884387aaecf9 100644 --- a/paddle/fluid/distributed/table/table.cc +++ b/paddle/fluid/distributed/table/table.cc @@ -21,6 +21,7 @@ #include "paddle/fluid/distributed/table/common_graph_table.h" #include "paddle/fluid/distributed/table/common_sparse_table.h" #include "paddle/fluid/distributed/table/sparse_geo_table.h" +#include "paddle/fluid/distributed/table/ssd_sparse_table.h" #include "paddle/fluid/distributed/table/tensor_accessor.h" #include "paddle/fluid/distributed/table/tensor_table.h" @@ -29,6 +30,7 @@ namespace distributed { REGISTER_PSCORE_CLASS(Table, GraphTable); REGISTER_PSCORE_CLASS(Table, CommonDenseTable); REGISTER_PSCORE_CLASS(Table, CommonSparseTable); +REGISTER_PSCORE_CLASS(Table, SSDSparseTable); REGISTER_PSCORE_CLASS(Table, SparseGeoTable); REGISTER_PSCORE_CLASS(Table, BarrierTable); REGISTER_PSCORE_CLASS(Table, TensorTable); diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index 2e8b551ea4e43c..9a0ce3900acf1c 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -118,6 +118,11 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { ") for entry attribute.") .SetDefault("none"); + AddAttr("table_class", + "(std::string, default " + ") for table_class.") + .SetDefault("none"); + AddAttr>( "table_names", "(string vector, the split table names that will be fetched from " diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index 91461aa26f341a..b8b31004881660 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -56,6 +56,8 @@ void BindDistFleetWrapper(py::module* m) { "DistFleetWrapper") .def(py::init([]() { return FleetWrapper::GetInstance(); })) .def("load_sparse", &FleetWrapper::LoadSparseOnServer) + .def("load_model", &FleetWrapper::LoadModel) + .def("load_one_table", &FleetWrapper::LoadModelOneTable) .def("init_server", &FleetWrapper::InitServer) .def("run_server", (uint64_t (FleetWrapper::*)(void)) & FleetWrapper::RunServer) diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py index 5f9a61371d34f4..3186df7db581a5 100644 --- 
a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -77,6 +77,7 @@ distributed_optimizer = fleet.distributed_optimizer save_inference_model = fleet.save_inference_model save_persistables = fleet.save_persistables +load_model = fleet.load_model minimize = fleet.minimize distributed_model = fleet.distributed_model step = fleet.step diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 5e883f1ac6cc91..9e5a31d6899e07 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -540,6 +540,29 @@ def init_server(self, *args, **kwargs): """ self._runtime_handle._init_server(*args, **kwargs) + def load_model(self, path, mode): + """ + load fleet model from path + + + Returns: + None + + Examples: + + .. code-block:: python + + import paddle.distributed.fleet as fleet + fleet.init() + + # build net + # fleet.distributed_optimizer(...) + + fleet.load_model("path", "mode") + + """ + self._runtime_handle.load_model(path, mode) + @is_non_distributed_check @inited_runtime_handler def run_server(self): diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py index d31fa549ad5623..59db687143163e 100644 --- a/python/paddle/distributed/fleet/runtime/the_one_ps.py +++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py @@ -35,6 +35,20 @@ def conv_indent(indent): PSERVER_SAVE_SUFFIX = ".shard" +def parse_table_class(varname, o_main_program): + from paddle.fluid.incubate.fleet.parameter_server.ir.public import is_distributed_sparse_op + from paddle.fluid.incubate.fleet.parameter_server.ir.public import is_sparse_op + + for op in o_main_program.global_block().ops: + if not is_distributed_sparse_op(op) and not is_sparse_op(op): + continue + + param_name = op.input("W")[0] + + if param_name == varname and op.type == "lookup_table": + return op.attr('table_class') + + class Accessor: def __init__(self): self.accessor_class = "" @@ -723,13 +737,15 @@ def _get_tables(): table.type = "PS_SPARSE_TABLE" table.shard_num = 256 + common.table_name = self.compiled_strategy.grad_name_to_param_name[ + ctx.origin_varnames()[0]] + if self.compiled_strategy.is_geo_mode(): table.table_class = "SparseGeoTable" else: - table.table_class = "CommonSparseTable" + table.table_class = parse_table_class( + common.table_name, self.origin_main_program) - common.table_name = self.compiled_strategy.grad_name_to_param_name[ - ctx.origin_varnames()[0]] else: table.type = "PS_DENSE_TABLE" table.table_class = "CommonDenseTable" @@ -1049,6 +1065,9 @@ def _save_inference_model(self, *args, **kwargs): def _save_persistables(self, *args, **kwargs): self._ps_inference_save_persistables(*args, **kwargs) + def load_model(self, path, mode): + self._worker.load_model(path, mode) + def _shrink(self, threshold): import paddle.distributed.fleet as fleet fleet.util.barrier() diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index 8c48033fc46f54..15425c661df1a0 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -967,6 +967,7 @@ def sparse_embedding(input, padding_idx=None, is_test=False, entry=None, + table_class=None, param_attr=None, dtype='float32'): helper = LayerHelper('sparse_embedding', **locals()) @@ -989,6 +990,13 @@ def sparse_embedding(input, padding_idx = -1 if padding_idx is None else padding_idx if 
padding_idx >= 0 else ( size[0] + padding_idx) + table_class_str = "CommonSparseTable" + if table_class is not None: + if table_class not in ["CommonSparseTable", "SSDSparseTable"]: + raise ValueError( + "table_class must be in [CommonSparseTable, SSDSparseTable]") + table_class_str = table_class + entry_str = "none" if entry is not None: @@ -1011,7 +1019,8 @@ def sparse_embedding(input, 'is_distributed': True, 'remote_prefetch': True, 'is_test': is_test, - 'entry': entry_str + 'entry': entry_str, + 'table_class': table_class_str }) return tmp diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py index d4af3e2f8042a5..89b2a8237dc65a 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py @@ -365,7 +365,41 @@ def _remove_lookup_table_grad_op_and_var(program): for name in remove_var: program.global_block()._remove_var(name) + def _remove_optimizer_var(program): + + embedding_w = {} + for idx, op in list(enumerate(program.global_block().ops)): + if op.type == "lookup_table_grad": + for name in op.input("W"): + embedding_w[name] = 1 + + optimize_vars = [] + optimize_op_role_vars = [] + optimize_need_delete_vars = [] + for op in _get_optimize_ops(program): + for name in op.input("Param"): + if name in embedding_w: + optimize_op_role_vars.extend(op.attr("op_role_var")) + for key_name in op.input_names: + if key_name == "LearningRate": + continue + for var in op.input(key_name): + optimize_vars.append(var) + + optimize_vars = list(set(optimize_vars)) + optimize_op_role_vars = list(set(optimize_op_role_vars)) + + for var in optimize_vars: + if var not in optimize_op_role_vars: + optimize_need_delete_vars.append(var) + need_delete_optimize_vars = list(set(optimize_need_delete_vars)) + + for name in need_delete_optimize_vars: + if program.global_block().has_var(name): + program.global_block()._remove_var(name) + _add_push_box_sparse_op(program) + _remove_optimizer_var(program) _remove_lookup_table_grad_op_and_var(program) return program From 4872ae433c448862feb08ce62d607a5b5e4c5e95 Mon Sep 17 00:00:00 2001 From: Thunderbrook Date: Thu, 20 May 2021 21:06:08 +0800 Subject: [PATCH 2/8] remove log --- .../distributed/table/depends/rocksdb_warpper.h | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/paddle/fluid/distributed/table/depends/rocksdb_warpper.h b/paddle/fluid/distributed/table/depends/rocksdb_warpper.h index d010ccfe1b8e83..65c1a83edbd818 100644 --- a/paddle/fluid/distributed/table/depends/rocksdb_warpper.h +++ b/paddle/fluid/distributed/table/depends/rocksdb_warpper.h @@ -46,14 +46,9 @@ class RocksDBHandler { bbto.filter_policy.reset(rocksdb::NewBloomFilterPolicy(20, false)); bbto.whole_key_filtering = true; options.table_factory.reset(rocksdb::NewBlockBasedTableFactory(bbto)); - VLOG(3) << "wxx aaa "; - // options.IncreaseParallelism(); - // options.OptimizeLevelStyleCompaction(); options.keep_log_file_num = 100; - // options.db_log_dir = "./log/rocksdb"; options.max_log_file_size = 50 * 1024 * 1024; // 50MB - // options.threads = 8; options.create_if_missing = true; options.use_direct_reads = true; options.use_direct_writes = true; @@ -66,42 +61,30 @@ class RocksDBHandler { options.max_write_buffer_number * options.write_buffer_size; options.min_write_buffer_number_to_merge = 1; options.target_file_size_base = 1024 * 1024 * 1024; // 1024MB - // 
options.verify_checksums_in_compaction = false; - // options.disable_auto_compactions = true; options.memtable_prefix_bloom_size_ratio = 0.02; options.num_levels = 4; options.max_open_files = -1; - VLOG(3) << "wxx bbb "; options.compression = rocksdb::kNoCompression; - // options.compaction_options_fifo = rocksdb::CompactionOptionsFIFO(); - // options.compaction_style = - // rocksdb::CompactionStyle::kCompactionStyleFIFO; options.level0_file_num_compaction_trigger = 8; options.level0_slowdown_writes_trigger = 1.8 * options.level0_file_num_compaction_trigger; options.level0_stop_writes_trigger = 3.6 * options.level0_file_num_compaction_trigger; - VLOG(3) << "wxx ccc "; if (!db_path.empty()) { std::string rm_cmd = "rm -rf " + db_path; system(rm_cmd.c_str()); } - VLOG(3) << "wxx ddd "; rocksdb::Status s = rocksdb::DB::Open(options, db_path, &_db); - VLOG(3) << "wxx eee "; assert(s.ok()); _handles.resize(colnum); - VLOG(3) << "wxx ahaha "; for (int i = 0; i < colnum; i++) { s = _db->CreateColumnFamily(options, "shard_" + std::to_string(i), &_handles[i]); - VLOG(3) << "wxx hihihi "; assert(s.ok()); } - VLOG(3) << "wxx fff "; LOG(INFO) << "DB initialize success, colnum:" << colnum; return 0; } From 9deb0193d592c8cad0949a573334d3f0f8a86d0f Mon Sep 17 00:00:00 2001 From: Thunderbrook Date: Fri, 21 May 2021 13:27:34 +0800 Subject: [PATCH 3/8] remove bz2 --- cmake/external/rocksdb.cmake | 12 +++++++++--- paddle/fluid/distributed/table/CMakeLists.txt | 2 +- .../distributed/table/depends/rocksdb_warpper.h | 1 - 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/cmake/external/rocksdb.cmake b/cmake/external/rocksdb.cmake index b73d3626a43c37..f5b85cc71a25f1 100644 --- a/cmake/external/rocksdb.cmake +++ b/cmake/external/rocksdb.cmake @@ -18,6 +18,7 @@ SET(ROCKSDB_SOURCES_DIR ${THIRD_PARTY_PATH}/rocksdb) SET(ROCKSDB_INSTALL_DIR ${THIRD_PARTY_PATH}/install/rocksdb) SET(ROCKSDB_INCLUDE_DIR "${ROCKSDB_INSTALL_DIR}/include" CACHE PATH "rocksdb include directory." FORCE) SET(ROCKSDB_LIBRARIES "${ROCKSDB_INSTALL_DIR}/lib/librocksdb.a" CACHE FILEPATH "rocksdb library." 
FORCE) +SET(ROCKSDB_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") INCLUDE_DIRECTORIES(${ROCKSDB_INCLUDE_DIR}) ExternalProject_Add( @@ -25,10 +26,15 @@ ExternalProject_Add( ${EXTERNAL_PROJECT_LOG_ARGS} PREFIX ${ROCKSDB_SOURCES_DIR} GIT_REPOSITORY "https://github.com/facebook/rocksdb" - GIT_TAG v5.1.4 + GIT_TAG v6.10.1 UPDATE_COMMAND "" - CONFIGURE_COMMAND "" - BUILD_COMMAND CXXFLAGS=-fPIC make static_lib + CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} + -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} + -DWITH_BZ2=OFF + -DWITH_GFLAGS=OFF + -DCMAKE_CXX_FLAGS=${ROCKSDB_CMAKE_CXX_FLAGS} + -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} +# BUILD_BYPRODUCTS ${ROCKSDB_SOURCES_DIR}/src/extern_rocksdb/librocksdb.a INSTALL_COMMAND mkdir -p ${ROCKSDB_INSTALL_DIR}/lib/ && cp ${ROCKSDB_SOURCES_DIR}/src/extern_rocksdb/librocksdb.a ${ROCKSDB_LIBRARIES} && cp -r ${ROCKSDB_SOURCES_DIR}/src/extern_rocksdb/include ${ROCKSDB_INSTALL_DIR}/ diff --git a/paddle/fluid/distributed/table/CMakeLists.txt b/paddle/fluid/distributed/table/CMakeLists.txt index 443fa928d1268f..e00c5594997fe9 100644 --- a/paddle/fluid/distributed/table/CMakeLists.txt +++ b/paddle/fluid/distributed/table/CMakeLists.txt @@ -29,4 +29,4 @@ cc_library(tensor_table SRCS tensor_table.cc DEPS eigen3 ps_framework_proto exec set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_library(table SRCS table.cc DEPS common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost) -target_link_libraries(table -lbz2) +#target_link_libraries(table -lbz2) diff --git a/paddle/fluid/distributed/table/depends/rocksdb_warpper.h b/paddle/fluid/distributed/table/depends/rocksdb_warpper.h index 65c1a83edbd818..83ecaf0c43c4f5 100644 --- a/paddle/fluid/distributed/table/depends/rocksdb_warpper.h +++ b/paddle/fluid/distributed/table/depends/rocksdb_warpper.h @@ -51,7 +51,6 @@ class RocksDBHandler { options.max_log_file_size = 50 * 1024 * 1024; // 50MB options.create_if_missing = true; options.use_direct_reads = true; - options.use_direct_writes = true; options.max_background_flushes = 5; options.max_background_compactions = 5; options.base_background_compactions = 10; From 7d9708fde91935ed409e6f7a5affc7371b33455e Mon Sep 17 00:00:00 2001 From: Thunderbrook Date: Fri, 21 May 2021 16:32:12 +0800 Subject: [PATCH 4/8] defalut value --- .../paddle/distributed/fleet/runtime/the_one_ps.py | 7 +++++-- python/paddle/fluid/contrib/layers/nn.py | 13 +++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py index 59db687143163e..91b70282295135 100644 --- a/python/paddle/distributed/fleet/runtime/the_one_ps.py +++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py @@ -45,8 +45,11 @@ def parse_table_class(varname, o_main_program): param_name = op.input("W")[0] - if param_name == varname and op.type == "lookup_table": - return op.attr('table_class') + if param_name == varname and op.type == "lookup_table" or op.type == "lookup_table_v2": + if op.has_attr('table_class'): + return op.attr('table_class') + else: + return "CommonSparseTable" class Accessor: diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index 15425c661df1a0..30316b77adcdfc 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -967,7 +967,7 @@ def sparse_embedding(input, padding_idx=None, is_test=False, 
entry=None, - table_class=None, + table_class="CommonSparseTable", param_attr=None, dtype='float32'): helper = LayerHelper('sparse_embedding', **locals()) @@ -990,12 +990,9 @@ def sparse_embedding(input, padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( size[0] + padding_idx) - table_class_str = "CommonSparseTable" - if table_class is not None: - if table_class not in ["CommonSparseTable", "SSDSparseTable"]: - raise ValueError( - "table_class must be in [CommonSparseTable, SSDSparseTable]") - table_class_str = table_class + if table_class not in ["CommonSparseTable", "SSDSparseTable"]: + raise ValueError( + "table_class must be in [CommonSparseTable, SSDSparseTable]") entry_str = "none" @@ -1020,7 +1017,7 @@ def sparse_embedding(input, 'remote_prefetch': True, 'is_test': is_test, 'entry': entry_str, - 'table_class': table_class_str + 'table_class': table_class }) return tmp From 3969f83fdc0c551c914030909a7b71cb6e617e6a Mon Sep 17 00:00:00 2001 From: Thunderbrook Date: Fri, 21 May 2021 16:33:53 +0800 Subject: [PATCH 5/8] code style --- paddle/fluid/distributed/table/depends/rocksdb_warpper.h | 4 ++-- paddle/fluid/distributed/table/ssd_sparse_table.cc | 4 ++-- paddle/fluid/distributed/table/ssd_sparse_table.h | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/distributed/table/depends/rocksdb_warpper.h b/paddle/fluid/distributed/table/depends/rocksdb_warpper.h index 83ecaf0c43c4f5..34e23af03279d5 100644 --- a/paddle/fluid/distributed/table/depends/rocksdb_warpper.h +++ b/paddle/fluid/distributed/table/depends/rocksdb_warpper.h @@ -1,11 +1,11 @@ // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/paddle/fluid/distributed/table/ssd_sparse_table.cc b/paddle/fluid/distributed/table/ssd_sparse_table.cc index ab41d43c750c4b..098f11173a99e3 100644 --- a/paddle/fluid/distributed/table/ssd_sparse_table.cc +++ b/paddle/fluid/distributed/table/ssd_sparse_table.cc @@ -1,11 +1,11 @@ // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/paddle/fluid/distributed/table/ssd_sparse_table.h b/paddle/fluid/distributed/table/ssd_sparse_table.h index 796dd55e2e0804..a3fe7d25e4a3e3 100644 --- a/paddle/fluid/distributed/table/ssd_sparse_table.h +++ b/paddle/fluid/distributed/table/ssd_sparse_table.h @@ -1,11 +1,11 @@ // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
 // You may obtain a copy of the License at
 //
 // http://www.apache.org/licenses/LICENSE-2.0
-// 
+//
 // Unless required by applicable law or agreed to in writing, software
 // distributed under the License is distributed on an "AS IS" BASIS,
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

From 2be19d7e2d82372ad282f89306e679e68a5b9e40 Mon Sep 17 00:00:00 2001
From: Thunderbrook
Date: Mon, 24 May 2021 17:04:16 +0800
Subject: [PATCH 6/8] parse table class

---
 paddle/fluid/distributed/table/depends/rocksdb_warpper.h | 2 +-
 paddle/fluid/distributed/table/ssd_sparse_table.h        | 2 +-
 python/paddle/distributed/fleet/runtime/the_one_ps.py    | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/distributed/table/depends/rocksdb_warpper.h b/paddle/fluid/distributed/table/depends/rocksdb_warpper.h
index 34e23af03279d5..39d2872c77fe67 100644
--- a/paddle/fluid/distributed/table/depends/rocksdb_warpper.h
+++ b/paddle/fluid/distributed/table/depends/rocksdb_warpper.h
@@ -3,7 +3,7 @@
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
-// 
+//
 // http://www.apache.org/licenses/LICENSE-2.0
 //
 // Unless required by applicable law or agreed to in writing, software
diff --git a/paddle/fluid/distributed/table/ssd_sparse_table.h b/paddle/fluid/distributed/table/ssd_sparse_table.h
index a3fe7d25e4a3e3..f26c89be1609a0 100644
--- a/paddle/fluid/distributed/table/ssd_sparse_table.h
+++ b/paddle/fluid/distributed/table/ssd_sparse_table.h
@@ -3,7 +3,7 @@
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
-// 
+//
 // http://www.apache.org/licenses/LICENSE-2.0
 //
 // Unless required by applicable law or agreed to in writing, software
diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py
index 91b70282295135..ac968f645575e0 100644
--- a/python/paddle/distributed/fleet/runtime/the_one_ps.py
+++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py
@@ -46,7 +46,7 @@ def parse_table_class(varname, o_main_program):
         param_name = op.input("W")[0]
 
         if param_name == varname and (op.type == "lookup_table" or op.type == "lookup_table_v2"):
-            if op.has_attr('table_class'):
+            if op.has_attr('table_class') and op.attr("table_class") != "none":
                 return op.attr('table_class')
             else:
                 return "CommonSparseTable"

From eba967ba38169b2ef1c2e974666598c53d0a3a7c Mon Sep 17 00:00:00 2001
From: Thunderbrook
Date: Mon, 24 May 2021 20:58:47 +0800
Subject: [PATCH 7/8] code style

---
 paddle/fluid/distributed/table/ssd_sparse_table.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/fluid/distributed/table/ssd_sparse_table.cc b/paddle/fluid/distributed/table/ssd_sparse_table.cc
index 098f11173a99e3..e2cbcf0a991d75 100644
--- a/paddle/fluid/distributed/table/ssd_sparse_table.cc
+++ b/paddle/fluid/distributed/table/ssd_sparse_table.cc
@@ -3,7 +3,7 @@
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
-// 
+//
 // http://www.apache.org/licenses/LICENSE-2.0
 //
 // Unless required by applicable law or agreed to in writing, software

From fea5a161f803a8daeb93ad35545f8129c18247b6 Mon Sep 17 00:00:00 2001
From: Thunderbrook
Date: Wed, 26 May 2021 17:25:22 +0800
Subject: [PATCH 8/8] add define

---
 cmake/third_party.cmake                            |  6 ++++--
 paddle/fluid/distributed/table/CMakeLists.txt      | 14 ++++++++++----
 .../distributed/table/depends/rocksdb_warpper.h    |  2 ++
 paddle/fluid/distributed/table/ssd_sparse_table.cc |  2 ++
 paddle/fluid/distributed/table/ssd_sparse_table.h  |  3 ++-
 paddle/fluid/distributed/table/table.cc            |  4 ++++
 6 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake
index 0dd7a86df26573..4d4ce060bfea82 100644
--- a/cmake/third_party.cmake
+++ b/cmake/third_party.cmake
@@ -305,8 +305,10 @@ if (WITH_PSCORE)
     include(external/libmct)     # download, build, install libmct
     list(APPEND third_party_deps extern_libmct)
 
-    include(external/rocksdb)    # download, build, install libmct
-    list(APPEND third_party_deps extern_rocksdb)
+    if (WITH_HETERPS)
+        include(external/rocksdb)    # download, build, install rocksdb
+        list(APPEND third_party_deps extern_rocksdb)
+    endif()
 endif()
 
 if(WITH_XBYAK)
diff --git a/paddle/fluid/distributed/table/CMakeLists.txt b/paddle/fluid/distributed/table/CMakeLists.txt
index e00c5594997fe9..c928ebe90ceb9e 100644
--- a/paddle/fluid/distributed/table/CMakeLists.txt
+++ b/paddle/fluid/distributed/table/CMakeLists.txt
@@ -16,10 +16,17 @@ set_source_files_properties(common_graph_table.cc PROPERTIES COMPILE_FLAGS ${DIS
 
 get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
 
-cc_library(common_table SRCS common_sparse_table.cc ssd_sparse_table.cc common_dense_table.cc
-sparse_geo_table.cc barrier_table.cc common_graph_table.cc DEPS ${TABLE_DEPS}
+set(EXTERN_DEP "")
+if(WITH_HETERPS)
+    set(TABLE_SRC common_sparse_table.cc ssd_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc)
+    set(EXTERN_DEP rocksdb)
+else()
+    set(TABLE_SRC common_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc)
+endif()
+
+cc_library(common_table SRCS ${TABLE_SRC} DEPS ${TABLE_DEPS}
 ${RPC_DEPS} graph_edge graph_node device_context string_helper
-simple_threadpool xxhash generator rocksdb)
+simple_threadpool xxhash generator ${EXTERN_DEP})
 
 set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 set_source_files_properties(tensor_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
@@ -29,4 +36,3 @@ cc_library(tensor_table SRCS tensor_table.cc DEPS eigen3 ps_framework_proto exec
 set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 cc_library(table SRCS table.cc DEPS common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)
 
-#target_link_libraries(table -lbz2)
diff --git a/paddle/fluid/distributed/table/depends/rocksdb_warpper.h b/paddle/fluid/distributed/table/depends/rocksdb_warpper.h
index 39d2872c77fe67..0e25a89cb14d72 100644
--- a/paddle/fluid/distributed/table/depends/rocksdb_warpper.h
+++ b/paddle/fluid/distributed/table/depends/rocksdb_warpper.h
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#ifdef PADDLE_WITH_HETERPS
 #include
 #include
 #include
@@ -154,3 +155,4 @@ class RocksDBHandler {
 };
 }
 }
+#endif
diff --git a/paddle/fluid/distributed/table/ssd_sparse_table.cc b/paddle/fluid/distributed/table/ssd_sparse_table.cc
index 515c42a266df71..5de6de3d2909d6 100644
--- a/paddle/fluid/distributed/table/ssd_sparse_table.cc
+++ b/paddle/fluid/distributed/table/ssd_sparse_table.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#ifdef PADDLE_WITH_HETERPS
 #include "paddle/fluid/distributed/table/ssd_sparse_table.h"
 
 DEFINE_string(rocksdb_path, "database", "path of sparse table rocksdb file");
@@ -358,3 +359,4 @@ int64_t SSDSparseTable::LoadFromText(
 
 } // namespace ps
 } // namespace paddle
+#endif
diff --git a/paddle/fluid/distributed/table/ssd_sparse_table.h b/paddle/fluid/distributed/table/ssd_sparse_table.h
index bb62648b08b5b8..5e85fa3ce59d13 100644
--- a/paddle/fluid/distributed/table/ssd_sparse_table.h
+++ b/paddle/fluid/distributed/table/ssd_sparse_table.h
@@ -15,7 +15,7 @@
 #pragma once
 #include "paddle/fluid/distributed/table/common_sparse_table.h"
 #include "paddle/fluid/distributed/table/depends/rocksdb_warpper.h"
-
+#ifdef PADDLE_WITH_HETERPS
 namespace paddle {
 namespace distributed {
 class SSDSparseTable : public CommonSparseTable {
@@ -58,3 +58,4 @@ class SSDSparseTable : public CommonSparseTable {
 
 } // namespace ps
 } // namespace paddle
+#endif
diff --git a/paddle/fluid/distributed/table/table.cc b/paddle/fluid/distributed/table/table.cc
index 25884387aaecf9..0f8753c0746341 100644
--- a/paddle/fluid/distributed/table/table.cc
+++ b/paddle/fluid/distributed/table/table.cc
@@ -21,7 +21,9 @@
 #include "paddle/fluid/distributed/table/common_graph_table.h"
 #include "paddle/fluid/distributed/table/common_sparse_table.h"
 #include "paddle/fluid/distributed/table/sparse_geo_table.h"
+#ifdef PADDLE_WITH_HETERPS
 #include "paddle/fluid/distributed/table/ssd_sparse_table.h"
+#endif
 #include "paddle/fluid/distributed/table/tensor_accessor.h"
 #include "paddle/fluid/distributed/table/tensor_table.h"
 
@@ -30,7 +32,9 @@ namespace distributed {
 REGISTER_PSCORE_CLASS(Table, GraphTable);
 REGISTER_PSCORE_CLASS(Table, CommonDenseTable);
 REGISTER_PSCORE_CLASS(Table, CommonSparseTable);
+#ifdef PADDLE_WITH_HETERPS
 REGISTER_PSCORE_CLASS(Table, SSDSparseTable);
+#endif
 REGISTER_PSCORE_CLASS(Table, SparseGeoTable);
 REGISTER_PSCORE_CLASS(Table, BarrierTable);
 REGISTER_PSCORE_CLASS(Table, TensorTable);
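
Usage sketch (editorial note, not taken from the patches above): with this series applied and Paddle built with WITH_HETERPS=ON, so that PADDLE_WITH_HETERPS is defined and SSDSparseTable is registered, the SSD-backed table is chosen per embedding through the table_class argument of sparse_embedding in python/paddle/fluid/contrib/layers/nn.py (patch 4 gives it the default "CommonSparseTable"); at runtime the SSD table keeps its RocksDB data under the path given by the rocksdb_path gflag, which defaults to "database". The snippet below is a minimal sketch for a parameter-server (fleet) training program; the variable names, shapes, and the fluid.contrib.layers.sparse_embedding call path are illustrative assumptions, not taken from the diffs.

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()  # sparse_embedding builds a static-graph lookup_table op

    # Hypothetical feature-id input; only the int64 id dtype matters here.
    ids = fluid.data(name="ids", shape=[None, 1], dtype="int64")

    # table_class must be "CommonSparseTable" (the default) or "SSDSparseTable".
    # The value is stored as the lookup_table op attribute 'table_class', which
    # parse_table_class() in the_one_ps.py reads back when building the server.
    emb = fluid.contrib.layers.sparse_embedding(
        input=ids,
        size=[1000000, 8],              # illustrative vocab size / embedding dim
        table_class="SSDSparseTable")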