Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 31 additions & 28 deletions vowpalwabbit/kernel_svm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,33 +84,34 @@ void free_svm_model(svm_model* model)

struct svm_params
{
size_t current_pass;
bool active;
bool active_pool_greedy;
bool para_active;
double active_c;

size_t pool_size;
size_t pool_pos;
size_t subsample; // NOTE: Eliminating subsample to only support 1/pool_size
size_t reprocess;

svm_model* model;
size_t maxcache;
size_t current_pass = 0;
bool active = false;
bool active_pool_greedy = false;
bool para_active = false;
double active_c = 0.0;

size_t pool_size = 0;
size_t pool_pos = 0;
size_t subsample = 0; // NOTE: Eliminating subsample to only support 1/pool_size
size_t reprocess = 0;

svm_model* model = nullptr;
size_t maxcache = 0;
// size_t curcache;

svm_example** pool;
float lambda;
svm_example** pool = nullptr;
float lambda = 0.f;

void* kernel_params;
size_t kernel_type;
void* kernel_params = nullptr;
size_t kernel_type = 0;

size_t local_begin, local_end;
size_t current_t;
size_t local_begin = 0;
size_t local_end = 0;
size_t current_t = 0;

float loss_sum;
float loss_sum = 0.f;

vw* all; // flatten, parallel
vw* all = nullptr; // flatten, parallel
std::shared_ptr<rand_state> _random_state;

~svm_params()
Expand Down Expand Up @@ -438,7 +439,7 @@ void predict(svm_params& params, svm_example** ec_arr, float* scores, size_t n)
}
}

void predict(svm_params& params, single_learner&, example& ec)
void predict(svm_params& params, base_learner&, example& ec)
{
flat_example* fec = flatten_sort_example(*(params.all), &ec);
if (fec)
Expand Down Expand Up @@ -762,7 +763,7 @@ void train(svm_params& params)
free(train_pool);
}

void learn(svm_params& params, single_learner&, example& ec)
void learn(svm_params& params, base_learner&, example& ec)
{
flat_example* fec = flatten_sort_example(*(params.all), &ec);
if (fec)
Expand Down Expand Up @@ -799,7 +800,7 @@ VW::LEARNER::base_learner* kernel_svm_setup(VW::setup_base_i& stack_builder)
options_i& options = *stack_builder.get_options();
vw& all = *stack_builder.get_all_pointer();

auto params = scoped_calloc_or_throw<svm_params>();
auto params = VW::make_unique<svm_params>();
std::string kernel_type;
float bandwidth = 1.f;
int degree = 2;
Expand Down Expand Up @@ -872,8 +873,10 @@ VW::LEARNER::base_learner* kernel_svm_setup(VW::setup_base_i& stack_builder)

params->all->weights.stride_shift(0);

learner<svm_params, example>& l =
init_learner(params, learn, predict, 1, stack_builder.get_setupfn_name(kernel_svm_setup));
l.set_save_load(save_load);
return make_base(l);
auto* l = make_base_learner(std::move(params), learn, predict, stack_builder.get_setupfn_name(kernel_svm_setup),
prediction_type_t::scalar, label_type_t::simple)
.set_save_load(save_load)
.build();

return make_base(*l);
}
57 changes: 30 additions & 27 deletions vowpalwabbit/lda_core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#endif

using namespace VW::config;
using namespace VW::LEARNER;

namespace logger = VW::io::logger;

Expand All @@ -57,15 +58,15 @@ class index_feature

struct lda
{
size_t topics;
float lda_alpha;
float lda_rho;
float lda_D;
float lda_epsilon;
size_t minibatch;
size_t topics = 0;
float lda_alpha = 0.f;
float lda_rho = 0.f;
float lda_D = 0.f;
float lda_epsilon = 0.f;
size_t minibatch = 0;
lda_math_mode mmode;

size_t finish_example_count;
size_t finish_example_count = 0;

v_array<float> Elogtheta;
v_array<float> decay_levels;
Expand All @@ -77,16 +78,16 @@ struct lda
v_array<float> v;
std::vector<index_feature> sorted_features;

bool compute_coherence_metrics;
bool compute_coherence_metrics = false;

// size by 1 << bits
std::vector<uint32_t> feature_counts;
std::vector<std::vector<size_t>> feature_to_example_map;

bool total_lambda_init;
bool total_lambda_init = false;

double example_t;
vw *all; // regressor, lda
double example_t = 0.0;
vw* all = nullptr; // regressor, lda

static constexpr float underflow_threshold = 1.0e-10f;
inline float digamma(float x);
Expand Down Expand Up @@ -963,7 +964,7 @@ void learn_batch(lda &l)
l.doc_lengths.clear();
}

void learn(lda &l, VW::LEARNER::single_learner &, example &ec)
void learn(lda& l, base_learner&, example& ec)
{
uint32_t num_ex = static_cast<uint32_t>(l.examples.size());
l.examples.push_back(&ec);
Expand All @@ -980,7 +981,7 @@ void learn(lda &l, VW::LEARNER::single_learner &, example &ec)
if (++num_ex == l.minibatch) learn_batch(l);
}

void learn_with_metrics(lda &l, VW::LEARNER::single_learner &base, example &ec)
void learn_with_metrics(lda& l, base_learner& base, example& ec)
{
if (l.all->passes_complete == 0)
{
Expand All @@ -1003,8 +1004,8 @@ void learn_with_metrics(lda &l, VW::LEARNER::single_learner &base, example &ec)
}

// placeholder
void predict(lda &l, VW::LEARNER::single_learner &base, example &ec) { learn(l, base, ec); }
void predict_with_metrics(lda &l, VW::LEARNER::single_learner &base, example &ec) { learn_with_metrics(l, base, ec); }
void predict(lda& l, base_learner& base, example& ec) { learn(l, base, ec); }
void predict_with_metrics(lda& l, base_learner& base, example& ec) { learn_with_metrics(l, base, ec); }

struct word_doc_frequency
{
Expand Down Expand Up @@ -1273,12 +1274,12 @@ std::istream &operator>>(std::istream &in, lda_math_mode &mmode)
return in;
}

VW::LEARNER::base_learner* lda_setup(VW::setup_base_i& stack_builder)
base_learner* lda_setup(VW::setup_base_i& stack_builder)
{
options_i& options = *stack_builder.get_options();
vw& all = *stack_builder.get_all_pointer();

auto ld = scoped_calloc_or_throw<lda>();
auto ld = VW::make_unique<lda>();
option_group_definition new_options("Latent Dirichlet Allocation");
int math_mode;
new_options.add(make_option("lda", ld->topics).keep().necessary().help("Run lda with <int> topics"))
Expand Down Expand Up @@ -1341,14 +1342,16 @@ VW::LEARNER::base_learner* lda_setup(VW::setup_base_i& stack_builder)

all.example_parser->lbl_parser = no_label::no_label_parser;

VW::LEARNER::learner<lda, example>& l = init_learner(ld, ld->compute_coherence_metrics ? learn_with_metrics : learn,
ld->compute_coherence_metrics ? predict_with_metrics : predict, UINT64_ONE << all.weights.stride_shift(),
prediction_type_t::scalars, stack_builder.get_setupfn_name(lda_setup), true);

l.set_save_load(save_load);
l.set_finish_example(finish_example);
l.set_end_examples(end_examples);
l.set_end_pass(end_pass);

return make_base(l);
auto* l = make_base_learner(std::move(ld), ld->compute_coherence_metrics ? learn_with_metrics : learn,
ld->compute_coherence_metrics ? predict_with_metrics : predict, stack_builder.get_setupfn_name(lda_setup),
prediction_type_t::scalars, label_type_t::nolabel)
.set_params_per_weight(UINT64_ONE << all.weights.stride_shift())
.set_learn_returns_prediction(true)
.set_save_load(save_load)
.set_finish_example(finish_example)
.set_end_examples(end_examples)
.set_end_pass(end_pass)
.build();

return make_base(*l);
}
28 changes: 19 additions & 9 deletions vowpalwabbit/lrq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@ using namespace VW::config;

struct LRQstate
{
vw* all; // feature creation, audit, hash_inv
vw* all = nullptr; // feature creation, audit, hash_inv
bool lrindices[256];
size_t orig_size[256];
std::set<std::string> lrpairs;
bool dropout;
uint64_t seed;
uint64_t initial_seed;
bool dropout = false;
uint64_t seed = 0;
uint64_t initial_seed = 0;

LRQstate()
{
std::fill(lrindices, lrindices + 256, false);
std::fill(orig_size, orig_size + 256, 0);
}
};

bool valid_int(const char* s)
Expand Down Expand Up @@ -167,7 +173,7 @@ base_learner* lrq_setup(VW::setup_base_i& stack_builder)
{
options_i& options = *stack_builder.get_options();
vw& all = *stack_builder.get_all_pointer();
auto lrq = scoped_calloc_or_throw<LRQstate>();
auto lrq = VW::make_unique<LRQstate>();
std::vector<std::string> lrq_names;
option_group_definition new_options("Low Rank Quadratics");
new_options.add(make_option("lrq", lrq_names).keep().necessary().help("use low rank quadratic features"))
Expand Down Expand Up @@ -213,10 +219,14 @@ base_learner* lrq_setup(VW::setup_base_i& stack_builder)

all.wpp = all.wpp * static_cast<uint64_t>(1 + maxk);
auto base = stack_builder.setup_base_learner();
learner<LRQstate, example>& l = init_learner(lrq, as_singleline(base), predict_or_learn<true>,
predict_or_learn<false>, 1 + maxk, stack_builder.get_setupfn_name(lrq_setup), base->learn_returns_prediction);
l.set_end_pass(reset_seed);

auto* l = make_reduction_learner(std::move(lrq), as_singleline(base), predict_or_learn<true>, predict_or_learn<false>,
stack_builder.get_setupfn_name(lrq_setup))
.set_params_per_weight(1 + maxk)
.set_learn_returns_prediction(base->learn_returns_prediction)
.set_end_pass(reset_seed)
.build();

// TODO: leaks memory ?
return make_base(l);
return make_base(*l);
}
26 changes: 18 additions & 8 deletions vowpalwabbit/lrqfa.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,17 @@ using namespace VW::config;

struct LRQFAstate
{
vw* all;
std::string field_name;
int k;
vw* all = nullptr;
std::string field_name = "";
int k = 0;
int field_id[256];
size_t orig_size[256];

LRQFAstate()
{
std::fill(field_id, field_id + 256, 0);
std::fill(orig_size, orig_size + 256, 0);
}
};

inline float cheesyrand(uint64_t x)
Expand Down Expand Up @@ -143,7 +149,7 @@ VW::LEARNER::base_learner* lrqfa_setup(VW::setup_base_i& stack_builder)

if (!options.add_parse_and_check_necessary(new_options)) return nullptr;

auto lrq = scoped_calloc_or_throw<LRQFAstate>();
auto lrq = VW::make_unique<LRQFAstate>();
lrq->all = &all;

std::string lrqopt = VW::decode_inline_hex(lrqfa);
Expand All @@ -156,9 +162,13 @@ VW::LEARNER::base_learner* lrqfa_setup(VW::setup_base_i& stack_builder)

all.wpp = all.wpp * static_cast<uint64_t>(1 + lrq->k);
auto base = stack_builder.setup_base_learner();
learner<LRQFAstate, example>& l = init_learner(lrq, as_singleline(base), predict_or_learn<true>,
predict_or_learn<false>, 1 + lrq->field_name.size() * lrq->k, stack_builder.get_setupfn_name(lrqfa_setup),
base->learn_returns_prediction);
size_t ws = 1 + lrq->field_name.size() * lrq->k;

auto* l = make_reduction_learner(std::move(lrq), as_singleline(base), predict_or_learn<true>, predict_or_learn<false>,
stack_builder.get_setupfn_name(lrqfa_setup))
.set_params_per_weight(ws)
.set_learn_returns_prediction(base->learn_returns_prediction)
.build();

return make_base(l);
return make_base(*l);
}
Loading