Skip to content

Commit cea1ba8

Browse files
add ctr accessor (#36601)
1 parent 19b02d9 commit cea1ba8

8 files changed

Lines changed: 893 additions & 48 deletions

File tree

paddle/fluid/distributed/ps.proto

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,11 @@ message TableParameter {
119119

120120
message TableAccessorParameter {
121121
optional string accessor_class = 1;
122-
// optional SparseSGDRuleParameter sparse_sgd_param = 2;
123122
optional uint32 fea_dim = 4 [ default = 11 ];
124123
optional uint32 embedx_dim = 5 [ default = 8 ];
125124
optional uint32 embedx_threshold = 6 [ default = 10 ];
126125
optional CtrAccessorParameter ctr_accessor_param = 7;
127126
repeated TableAccessorSaveParameter table_accessor_save_param = 8;
128-
// optional SparseCommonSGDRuleParameter sparse_commonsgd_param = 9;
129127
optional SparseCommonSGDRuleParameter embed_sgd_param = 10;
130128
optional SparseCommonSGDRuleParameter embedx_sgd_param = 11;
131129
}
@@ -182,13 +180,6 @@ message TableAccessorSaveParameter {
182180
optional string deconverter = 3;
183181
}
184182

185-
// message SparseSGDRuleParameter {
186-
// optional double learning_rate = 1 [default = 0.05];
187-
// optional double initial_g2sum = 2 [default = 3.0];
188-
// optional double initial_range = 3 [default = 0.0001];
189-
// repeated float weight_bounds = 4;
190-
//}
191-
192183
message SparseCommonSGDRuleParameter {
193184
optional string name = 1;
194185
optional SparseNaiveSGDRuleParameter naive = 2;

paddle/fluid/distributed/table/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ cc_library(tensor_table SRCS tensor_table.cc DEPS eigen3 ps_framework_proto exec
3636
set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
3737

3838
set_source_files_properties(sparse_sgd_rule.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
39+
set_source_files_properties(ctr_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
3940
cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framework_proto)
41+
cc_library(ctr_accessor SRCS ctr_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
4042

41-
42-
cc_library(table SRCS table.cc DEPS common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost sparse_sgd_rule)
43+
cc_library(table SRCS table.cc DEPS common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost ctr_accessor)
Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/distributed/table/ctr_accessor.h"
16+
#include <gflags/gflags.h>
17+
#include "glog/logging.h"
18+
#include "paddle/fluid/string/string_helper.h"
19+
20+
namespace paddle {
21+
namespace distributed {
22+
23+
int CtrCommonAccessor::initialize() {
24+
auto name = _config.embed_sgd_param().name();
25+
_embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
26+
_embed_sgd_rule->load_config(_config.embed_sgd_param(), 1);
27+
28+
name = _config.embedx_sgd_param().name();
29+
_embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
30+
_embedx_sgd_rule->load_config(_config.embedx_sgd_param(),
31+
_config.embedx_dim());
32+
33+
common_feature_value.embed_sgd_dim = _embed_sgd_rule->dim();
34+
common_feature_value.embedx_dim = _config.embedx_dim();
35+
common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->dim();
36+
_show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
37+
38+
return 0;
39+
}
40+
41+
// Total dimension of a stored feature value, as reported by
// common_feature_value.
size_t CtrCommonAccessor::dim() {
  const auto value_dim = common_feature_value.dim();
  return value_dim;
}
42+
43+
size_t CtrCommonAccessor::dim_size(size_t dim) {
44+
auto embedx_dim = _config.embedx_dim();
45+
return common_feature_value.dim_size(dim, embedx_dim);
46+
}
47+
48+
// Size of a stored feature value, as reported by common_feature_value.
size_t CtrCommonAccessor::size() {
  const auto value_size = common_feature_value.size();
  return value_size;
}
49+
50+
size_t CtrCommonAccessor::mf_size() {
51+
return (_config.embedx_dim() + common_feature_value.embedx_sgd_dim) *
52+
sizeof(float); // embedx embedx_g2sum
53+
}
54+
55+
// pull value
56+
size_t CtrCommonAccessor::select_dim() {
57+
auto embedx_dim = _config.embedx_dim();
58+
return 1 + embedx_dim;
59+
}
60+
61+
// Every pull-value dimension is one float wide; `dim` is not consulted.
size_t CtrCommonAccessor::select_dim_size(size_t dim) {
  const size_t float_bytes = sizeof(float);
  return float_bytes;
}
62+
63+
// Byte size of one pull value: select_dim() floats.
size_t CtrCommonAccessor::select_size() {
  const size_t float_count = select_dim();
  return float_count * sizeof(float);
}
64+
65+
// push value
66+
size_t CtrCommonAccessor::update_dim() {
67+
auto embedx_dim = _config.embedx_dim();
68+
return 4 + embedx_dim;
69+
}
70+
71+
// Every push-value dimension is one float wide; `dim` is not consulted.
size_t CtrCommonAccessor::update_dim_size(size_t dim) {
  const size_t float_bytes = sizeof(float);
  return float_bytes;
}
72+
73+
// Byte size of one push value: update_dim() floats.
size_t CtrCommonAccessor::update_size() {
  const size_t float_count = update_dim();
  return float_count * sizeof(float);
}
74+
75+
bool CtrCommonAccessor::shrink(float* value) {
76+
auto base_threshold = _config.ctr_accessor_param().base_threshold();
77+
auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
78+
auto delete_after_unseen_days =
79+
_config.ctr_accessor_param().delete_after_unseen_days();
80+
auto delete_threshold = _config.ctr_accessor_param().delete_threshold();
81+
82+
// time_decay first
83+
common_feature_value.show(value) *= _show_click_decay_rate;
84+
common_feature_value.click(value) *= _show_click_decay_rate;
85+
86+
// shrink after
87+
auto score = show_click_score(common_feature_value.show(value),
88+
common_feature_value.click(value));
89+
auto unseen_days = common_feature_value.unseen_days(value);
90+
if (score < delete_threshold || unseen_days > delete_after_unseen_days) {
91+
return true;
92+
}
93+
return false;
94+
}
95+
96+
bool CtrCommonAccessor::save(float* value, int param) {
97+
auto base_threshold = _config.ctr_accessor_param().base_threshold();
98+
auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
99+
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
100+
if (param == 2) {
101+
delta_threshold = 0;
102+
}
103+
switch (param) {
104+
// save all
105+
case 0: {
106+
return true;
107+
}
108+
// save xbox delta
109+
case 1:
110+
// save xbox base
111+
case 2: {
112+
if (show_click_score(common_feature_value.show(value),
113+
common_feature_value.click(value)) >=
114+
base_threshold &&
115+
common_feature_value.delta_score(value) >= delta_threshold &&
116+
common_feature_value.unseen_days(value) <= delta_keep_days) {
117+
// do this after save, because it must not be modified when retry
118+
if (param == 2) {
119+
common_feature_value.delta_score(value) = 0;
120+
}
121+
return true;
122+
} else {
123+
return false;
124+
}
125+
}
126+
// already decayed in shrink
127+
case 3: {
128+
// do this after save, because it must not be modified when retry
129+
// common_feature_value.unseen_days(value)++;
130+
return true;
131+
}
132+
// save revert batch_model
133+
case 5: {
134+
return true;
135+
}
136+
default:
137+
return true;
138+
}
139+
}
140+
141+
void CtrCommonAccessor::update_stat_after_save(float* value, int param) {
142+
auto base_threshold = _config.ctr_accessor_param().base_threshold();
143+
auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
144+
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
145+
if (param == 2) {
146+
delta_threshold = 0;
147+
}
148+
switch (param) {
149+
case 1: {
150+
if (show_click_score(common_feature_value.show(value),
151+
common_feature_value.click(value)) >=
152+
base_threshold &&
153+
common_feature_value.delta_score(value) >= delta_threshold &&
154+
common_feature_value.unseen_days(value) <= delta_keep_days) {
155+
common_feature_value.delta_score(value) = 0;
156+
}
157+
}
158+
return;
159+
case 3: {
160+
common_feature_value.unseen_days(value)++;
161+
}
162+
return;
163+
default:
164+
return;
165+
}
166+
}
167+
168+
int32_t CtrCommonAccessor::create(float** values, size_t num) {
169+
auto embedx_dim = _config.embedx_dim();
170+
for (size_t value_item = 0; value_item < num; ++value_item) {
171+
float* value = values[value_item];
172+
value[common_feature_value.unseen_days_index()] = 0;
173+
value[common_feature_value.delta_score_index()] = 0;
174+
value[common_feature_value.show_index()] = 0;
175+
value[common_feature_value.click_index()] = 0;
176+
value[common_feature_value.slot_index()] = -1;
177+
_embed_sgd_rule->init_value(
178+
value + common_feature_value.embed_w_index(),
179+
value + common_feature_value.embed_g2sum_index());
180+
_embedx_sgd_rule->init_value(
181+
value + common_feature_value.embedx_w_index(),
182+
value + common_feature_value.embedx_g2sum_index(), false);
183+
}
184+
return 0;
185+
}
186+
187+
bool CtrCommonAccessor::need_extend_mf(float* value) {
188+
float show = value[common_feature_value.show_index()];
189+
float click = value[common_feature_value.click_index()];
190+
float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() +
191+
click * _config.ctr_accessor_param().click_coeff();
192+
return score >= _config.embedx_threshold();
193+
}
194+
195+
bool CtrCommonAccessor::has_mf(size_t size) {
196+
return size > common_feature_value.embedx_g2sum_index();
197+
}
198+
199+
// from CommonFeatureValue to CtrCommonPullValue
200+
int32_t CtrCommonAccessor::select(float** select_values, const float** values,
201+
size_t num) {
202+
auto embedx_dim = _config.embedx_dim();
203+
for (size_t value_item = 0; value_item < num; ++value_item) {
204+
float* select_value = select_values[value_item];
205+
const float* value = values[value_item];
206+
select_value[CtrCommonPullValue::embed_w_index()] =
207+
value[common_feature_value.embed_w_index()];
208+
memcpy(select_value + CtrCommonPullValue::embedx_w_index(),
209+
value + common_feature_value.embedx_w_index(),
210+
embedx_dim * sizeof(float));
211+
}
212+
return 0;
213+
}
214+
215+
// from CtrCommonPushValue to CtrCommonPushValue
216+
// first dim: item
217+
// second dim: field num
218+
int32_t CtrCommonAccessor::merge(float** update_values,
219+
const float** other_update_values,
220+
size_t num) {
221+
auto embedx_dim = _config.embedx_dim();
222+
size_t total_dim = CtrCommonPushValue::dim(embedx_dim);
223+
for (size_t value_item = 0; value_item < num; ++value_item) {
224+
float* update_value = update_values[value_item];
225+
const float* other_update_value = other_update_values[value_item];
226+
for (auto i = 0u; i < total_dim; ++i) {
227+
if (i != CtrCommonPushValue::slot_index()) {
228+
update_value[i] += other_update_value[i];
229+
}
230+
}
231+
}
232+
return 0;
233+
}
234+
235+
// from CtrCommonPushValue to CommonFeatureValue
236+
// first dim: item
237+
// second dim: field num
238+
int32_t CtrCommonAccessor::update(float** update_values,
239+
const float** push_values, size_t num) {
240+
auto embedx_dim = _config.embedx_dim();
241+
for (size_t value_item = 0; value_item < num; ++value_item) {
242+
float* update_value = update_values[value_item];
243+
const float* push_value = push_values[value_item];
244+
float push_show = push_value[CtrCommonPushValue::show_index()];
245+
float push_click = push_value[CtrCommonPushValue::click_index()];
246+
float slot = push_value[CtrCommonPushValue::slot_index()];
247+
update_value[common_feature_value.show_index()] += push_show;
248+
update_value[common_feature_value.click_index()] += push_click;
249+
update_value[common_feature_value.slot_index()] = slot;
250+
update_value[common_feature_value.delta_score_index()] +=
251+
(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
252+
push_click * _config.ctr_accessor_param().click_coeff();
253+
update_value[common_feature_value.unseen_days_index()] = 0;
254+
_embed_sgd_rule->update_value(
255+
update_value + common_feature_value.embed_w_index(),
256+
update_value + common_feature_value.embed_g2sum_index(),
257+
push_value + CtrCommonPushValue::embed_g_index());
258+
_embedx_sgd_rule->update_value(
259+
update_value + common_feature_value.embedx_w_index(),
260+
update_value + common_feature_value.embedx_g2sum_index(),
261+
push_value + CtrCommonPushValue::embedx_g_index());
262+
}
263+
return 0;
264+
}
265+
266+
bool CtrCommonAccessor::create_value(int stage, const float* value) {
267+
// stage == 0, pull
268+
// stage == 1, push
269+
if (stage == 0) {
270+
return true;
271+
} else if (stage == 1) {
272+
// operation
273+
auto show = CtrCommonPushValue::show_const(value);
274+
auto click = CtrCommonPushValue::click_const(value);
275+
auto score = show_click_score(show, click);
276+
if (score <= 0) {
277+
return false;
278+
}
279+
if (score >= 1) {
280+
return true;
281+
}
282+
return local_uniform_real_distribution<float>()(local_random_engine()) <
283+
score;
284+
} else {
285+
return true;
286+
}
287+
}
288+
289+
float CtrCommonAccessor::show_click_score(float show, float click) {
290+
auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
291+
auto click_coeff = _config.ctr_accessor_param().click_coeff();
292+
return (show - click) * nonclk_coeff + click * click_coeff;
293+
}
294+
295+
std::string CtrCommonAccessor::parse_to_string(const float* v, int param) {
296+
thread_local std::ostringstream os;
297+
os.clear();
298+
os.str("");
299+
os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " "
300+
<< v[5];
301+
for (int i = common_feature_value.embed_g2sum_index();
302+
i < common_feature_value.embedx_w_index(); i++) {
303+
os << " " << v[i];
304+
}
305+
auto show = common_feature_value.show_const(v);
306+
auto click = common_feature_value.click_const(v);
307+
auto score = show_click_score(show, click);
308+
if (score >= _config.embedx_threshold()) {
309+
for (auto i = common_feature_value.embedx_w_index();
310+
i < common_feature_value.dim(); ++i) {
311+
os << " " << v[i];
312+
}
313+
}
314+
return os.str();
315+
}
316+
317+
int CtrCommonAccessor::parse_from_string(const std::string& str, float* value) {
318+
int embedx_dim = _config.embedx_dim();
319+
320+
_embedx_sgd_rule->init_value(
321+
value + common_feature_value.embedx_w_index(),
322+
value + common_feature_value.embedx_g2sum_index());
323+
auto ret = paddle::string::str_to_float(str.data(), value);
324+
CHECK(ret >= 6) << "expect more than 6 real:" << ret;
325+
return ret;
326+
}
327+
328+
} // namespace distributed
329+
} // namespace paddle

0 commit comments

Comments
 (0)