Skip to content

Commit a9b5445

Browse files
committed
merge
2 parents f43d085 + 676a92c commit a9b5445

File tree

138 files changed

+3866
-1698
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

138 files changed

+3866
-1698
lines changed

paddle/fluid/distributed/ps/table/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,13 @@ set_source_files_properties(sparse_sgd_rule.cc PROPERTIES COMPILE_FLAGS ${DISTRI
3535
set_source_files_properties(ctr_double_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
3636
set_source_files_properties(ctr_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
3737
set_source_files_properties(sparse_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
38+
set_source_files_properties(ctr_dymf_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
3839
set_source_files_properties(memory_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
3940
set_source_files_properties(ssd_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
4041
set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
4142

4243
cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framework_proto)
43-
cc_library(ctr_accessor SRCS ctr_accessor.cc ctr_double_accessor.cc sparse_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
44+
cc_library(ctr_accessor SRCS ctr_accessor.cc ctr_double_accessor.cc sparse_accessor.cc ctr_dymf_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
4445
cc_library(sparse_table SRCS memory_sparse_table.cc ssd_sparse_table.cc memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table rocksdb)
4546

4647
cc_library(table SRCS table.cc DEPS sparse_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)

paddle/fluid/distributed/ps/table/common_graph_table.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ class GraphTable : public Table {
566566
int32_t dump_edges_to_ssd(int idx);
567567
int32_t get_partition_num(int idx) { return partitions[idx].size(); }
568568
std::vector<int64_t> get_partition(int idx, int index) {
569-
if (idx >= partitions.size() || index >= partitions[idx].size())
569+
if (idx >= (int)partitions.size() || index >= (int)partitions[idx].size())
570570
return std::vector<int64_t>();
571571
return partitions[idx][index];
572572
}
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h"
16+
#include <gflags/gflags.h>
17+
#include "glog/logging.h"
18+
#include "paddle/fluid/string/string_helper.h"
19+
20+
namespace paddle {
21+
namespace distributed {
22+
23+
int CtrDymfAccessor::Initialize() {
24+
auto name = _config.embed_sgd_param().name();
25+
_embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
26+
_embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);
27+
28+
name = _config.embedx_sgd_param().name();
29+
_embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
30+
_embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
31+
_config.embedx_dim());
32+
33+
common_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim();
34+
common_feature_value.embedx_dim = _config.embedx_dim();
35+
common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim();
36+
_show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
37+
_ssd_unseenday_threshold =
38+
_config.ctr_accessor_param().ssd_unseenday_threshold();
39+
40+
if (_config.ctr_accessor_param().show_scale()) {
41+
_show_scale = true;
42+
}
43+
VLOG(0) << " INTO CtrDymfAccessor::Initialize()";
44+
InitAccessorInfo();
45+
return 0;
46+
}
47+
48+
void CtrDymfAccessor::InitAccessorInfo() {
49+
_accessor_info.dim = common_feature_value.Dim();
50+
_accessor_info.size = common_feature_value.Size();
51+
52+
auto embedx_dim = _config.embedx_dim();
53+
VLOG(0) << "InitAccessorInfo embedx_dim:" << embedx_dim;
54+
_accessor_info.select_dim = 3 + embedx_dim;
55+
_accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
56+
_accessor_info.update_dim = 4 + embedx_dim;
57+
_accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
58+
_accessor_info.mf_size =
59+
(embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float);
60+
}
61+
62+
bool CtrDymfAccessor::Shrink(float* value) {
63+
auto base_threshold = _config.ctr_accessor_param().base_threshold();
64+
auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
65+
auto delete_after_unseen_days =
66+
_config.ctr_accessor_param().delete_after_unseen_days();
67+
auto delete_threshold = _config.ctr_accessor_param().delete_threshold();
68+
69+
// time_decay first
70+
common_feature_value.Show(value) *= _show_click_decay_rate;
71+
common_feature_value.Click(value) *= _show_click_decay_rate;
72+
73+
// shrink after
74+
auto score = ShowClickScore(common_feature_value.Show(value),
75+
common_feature_value.Click(value));
76+
auto unseen_days = common_feature_value.UnseenDays(value);
77+
if (score < delete_threshold || unseen_days > delete_after_unseen_days) {
78+
return true;
79+
}
80+
return false;
81+
}
82+
83+
bool CtrDymfAccessor::SaveCache(float* value, int param,
84+
double global_cache_threshold) {
85+
auto base_threshold = _config.ctr_accessor_param().base_threshold();
86+
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
87+
if (ShowClickScore(common_feature_value.Show(value),
88+
common_feature_value.Click(value)) >= base_threshold &&
89+
common_feature_value.UnseenDays(value) <= delta_keep_days) {
90+
return common_feature_value.Show(value) > global_cache_threshold;
91+
}
92+
return false;
93+
}
94+
95+
bool CtrDymfAccessor::SaveSSD(float* value) {
96+
if (common_feature_value.UnseenDays(value) > _ssd_unseenday_threshold) {
97+
return true;
98+
}
99+
return false;
100+
}
101+
102+
bool CtrDymfAccessor::Save(float* value, int param) {
103+
auto base_threshold = _config.ctr_accessor_param().base_threshold();
104+
auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
105+
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
106+
if (param == 2) {
107+
delta_threshold = 0;
108+
}
109+
switch (param) {
110+
// save all
111+
case 0: {
112+
return true;
113+
}
114+
// save xbox delta
115+
case 1:
116+
// save xbox base
117+
case 2: {
118+
if (ShowClickScore(common_feature_value.Show(value),
119+
common_feature_value.Click(value)) >= base_threshold &&
120+
common_feature_value.DeltaScore(value) >= delta_threshold &&
121+
common_feature_value.UnseenDays(value) <= delta_keep_days) {
122+
// do this after save, because it must not be modified when retry
123+
if (param == 2) {
124+
common_feature_value.DeltaScore(value) = 0;
125+
}
126+
return true;
127+
} else {
128+
return false;
129+
}
130+
}
131+
// already decayed in shrink
132+
case 3: {
133+
// do this after save, because it must not be modified when retry
134+
// common_feature_value.UnseenDays(value)++;
135+
return true;
136+
}
137+
// save revert batch_model
138+
case 5: {
139+
return true;
140+
}
141+
default:
142+
return true;
143+
}
144+
}
145+
146+
void CtrDymfAccessor::UpdateStatAfterSave(float* value, int param) {
147+
auto base_threshold = _config.ctr_accessor_param().base_threshold();
148+
auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
149+
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
150+
if (param == 2) {
151+
delta_threshold = 0;
152+
}
153+
switch (param) {
154+
case 1: {
155+
if (ShowClickScore(common_feature_value.Show(value),
156+
common_feature_value.Click(value)) >= base_threshold &&
157+
common_feature_value.DeltaScore(value) >= delta_threshold &&
158+
common_feature_value.UnseenDays(value) <= delta_keep_days) {
159+
common_feature_value.DeltaScore(value) = 0;
160+
}
161+
}
162+
return;
163+
case 3: {
164+
common_feature_value.UnseenDays(value)++;
165+
}
166+
return;
167+
default:
168+
return;
169+
}
170+
}
171+
172+
int32_t CtrDymfAccessor::Create(float** values, size_t num) {
173+
auto embedx_dim = _config.embedx_dim();
174+
for (size_t value_item = 0; value_item < num; ++value_item) {
175+
float* value = values[value_item];
176+
value[common_feature_value.UnseenDaysIndex()] = 0;
177+
value[common_feature_value.DeltaScoreIndex()] = 0;
178+
value[common_feature_value.ShowIndex()] = 0;
179+
value[common_feature_value.ClickIndex()] = 0;
180+
value[common_feature_value.SlotIndex()] = -1;
181+
value[common_feature_value.MfDimIndex()] = -1;
182+
_embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(),
183+
value + common_feature_value.EmbedG2SumIndex());
184+
_embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(),
185+
value + common_feature_value.EmbedxG2SumIndex(),
186+
false);
187+
}
188+
return 0;
189+
}
190+
191+
bool CtrDymfAccessor::NeedExtendMF(float* value) {
192+
float show = value[common_feature_value.ShowIndex()];
193+
float click = value[common_feature_value.ClickIndex()];
194+
float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() +
195+
click * _config.ctr_accessor_param().click_coeff();
196+
return score >= _config.embedx_threshold();
197+
}
198+
199+
bool CtrDymfAccessor::HasMF(size_t size) {
200+
return size > common_feature_value.EmbedxG2SumIndex();
201+
}
202+
203+
// from CommonFeatureValue to CtrDymfPullValue
204+
int32_t CtrDymfAccessor::Select(float** select_values, const float** values,
205+
size_t num) {
206+
auto embedx_dim = _config.embedx_dim();
207+
for (size_t value_item = 0; value_item < num; ++value_item) {
208+
float* select_value = select_values[value_item];
209+
const float* value = values[value_item];
210+
select_value[CtrDymfPullValue::ShowIndex()] =
211+
value[common_feature_value.ShowIndex()];
212+
select_value[CtrDymfPullValue::ClickIndex()] =
213+
value[common_feature_value.ClickIndex()];
214+
select_value[CtrDymfPullValue::EmbedWIndex()] =
215+
value[common_feature_value.EmbedWIndex()];
216+
memcpy(select_value + CtrDymfPullValue::EmbedxWIndex(),
217+
value + common_feature_value.EmbedxWIndex(),
218+
embedx_dim * sizeof(float));
219+
}
220+
return 0;
221+
}
222+
223+
// from CtrDymfPushValue to CtrDymfPushValue
224+
// first dim: item
225+
// second dim: field num
226+
int32_t CtrDymfAccessor::Merge(float** update_values,
227+
const float** other_update_values, size_t num) {
228+
// currently merge in cpu is not supported
229+
return 0;
230+
}
231+
232+
// from CtrDymfPushValue to CommonFeatureValue
233+
// first dim: item
234+
// second dim: field num
235+
int32_t CtrDymfAccessor::Update(float** update_values,
236+
const float** push_values, size_t num) {
237+
// currently update in cpu is not supported
238+
return 0;
239+
}
240+
241+
bool CtrDymfAccessor::CreateValue(int stage, const float* value) {
242+
// stage == 0, pull
243+
// stage == 1, push
244+
if (stage == 0) {
245+
return true;
246+
} else if (stage == 1) {
247+
// operation
248+
auto show = CtrDymfPushValue::Show(const_cast<float*>(value));
249+
auto click = CtrDymfPushValue::Click(const_cast<float*>(value));
250+
auto score = ShowClickScore(show, click);
251+
if (score <= 0) {
252+
return false;
253+
}
254+
if (score >= 1) {
255+
return true;
256+
}
257+
return local_uniform_real_distribution<float>()(local_random_engine()) <
258+
score;
259+
} else {
260+
return true;
261+
}
262+
}
263+
264+
float CtrDymfAccessor::ShowClickScore(float show, float click) {
265+
auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
266+
auto click_coeff = _config.ctr_accessor_param().click_coeff();
267+
return (show - click) * nonclk_coeff + click * click_coeff;
268+
}
269+
270+
std::string CtrDymfAccessor::ParseToString(const float* v, int param) {
271+
/*
272+
float unseen_days;
273+
float delta_score;
274+
float show;
275+
float click;
276+
float embed_w;
277+
std::vector<float> embed_g2sum; // float embed_g2sum
278+
float slot;
279+
float mf_dim;
280+
std::<vector>float embedx_g2sum; // float embedx_g2sum
281+
std::vector<float> embedx_w;
282+
*/
283+
thread_local std::ostringstream os;
284+
os.clear();
285+
os.str("");
286+
os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4];
287+
// << v[5] << " " << v[6];
288+
for (int i = common_feature_value.EmbedG2SumIndex();
289+
i < common_feature_value.EmbedxWIndex(); i++) {
290+
os << " " << v[i];
291+
}
292+
os << " " << common_feature_value.Slot(const_cast<float*>(v)) << " "
293+
<< common_feature_value.MfDim(const_cast<float*>(v));
294+
auto show = common_feature_value.Show(const_cast<float*>(v));
295+
auto click = common_feature_value.Click(const_cast<float*>(v));
296+
auto score = ShowClickScore(show, click);
297+
if (score >= _config.embedx_threshold() &&
298+
param > common_feature_value.EmbedxG2SumIndex()) {
299+
VLOG(0) << "common_feature_value.EmbedxG2SumIndex():"
300+
<< common_feature_value.EmbedxG2SumIndex();
301+
for (auto i = common_feature_value.EmbedxG2SumIndex();
302+
i < common_feature_value.Dim(); ++i) {
303+
os << " " << v[i];
304+
}
305+
}
306+
return os.str();
307+
}
308+
309+
int CtrDymfAccessor::ParseFromString(const std::string& str, float* value) {
310+
auto ret = paddle::string::str_to_float(str.data(), value);
311+
CHECK(ret >= 7) << "expect more than 7 real:" << ret;
312+
return ret;
313+
}
314+
315+
} // namespace distributed
316+
} // namespace paddle

0 commit comments

Comments
 (0)