Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 34 additions & 31 deletions vowpalwabbit/stagewise_poly.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "gd.h"
#include "accumulate.h"
#include "label_parser.h"
#include "reductions.h"
#include "vw.h"
#include "vw_allreduce.h"
Expand All @@ -34,34 +35,34 @@ struct sort_data

struct stagewise_poly
{
vw *all; // many uses, unmodular reduction
vw* all = nullptr; // many uses, unmodular reduction

float sched_exponent;
uint32_t batch_sz;
bool batch_sz_double;
float sched_exponent = 0.f;
uint32_t batch_sz = 0;
bool batch_sz_double = false;

sort_data *sd;
size_t sd_len;
uint8_t *depthsbits; // interleaved array storing depth information and parent/cycle bits
sort_data* sd = nullptr;
size_t sd_len = 0;
uint8_t* depthsbits = nullptr; // interleaved array storing depth information and parent/cycle bits

uint64_t sum_sparsity; // of synthetic example
uint64_t sum_input_sparsity; // of input example
uint64_t num_examples;
uint64_t sum_sparsity = 0; // of synthetic example
uint64_t sum_input_sparsity = 0; // of input example
uint64_t num_examples = 0;
// following three are for parallel (see end_pass())
uint64_t sum_sparsity_sync;
uint64_t sum_input_sparsity_sync;
uint64_t num_examples_sync;
uint64_t sum_sparsity_sync = 0;
uint64_t sum_input_sparsity_sync = 0;
uint64_t num_examples_sync = 0;

example synth_ec;
// following is bookkeeping in synth_ec creation (dfs)
feature synth_rec_f;
example *original_ec;
uint32_t cur_depth;
bool training;
uint64_t last_example_counter;
size_t numpasses;
uint32_t next_batch_sz;
bool update_support;
feature synth_rec_f{0.f, 0};
example* original_ec = nullptr;
uint32_t cur_depth = 0;
bool training = false;
uint64_t last_example_counter = 0;
size_t numpasses = 0;
uint32_t next_batch_sz = 0;
bool update_support = false;

#ifdef DEBUG
uint32_t max_depth;
Expand Down Expand Up @@ -655,7 +656,7 @@ base_learner* stagewise_poly_setup(VW::setup_base_i& stack_builder)
{
options_i& options = *stack_builder.get_options();
vw& all = *stack_builder.get_all_pointer();
auto poly = scoped_calloc_or_throw<stagewise_poly>();
auto poly = VW::make_unique<stagewise_poly>();
bool stage_poly = false;
option_group_definition new_options("Stagewise polynomial options");
new_options
Expand All @@ -675,8 +676,8 @@ base_learner* stagewise_poly_setup(VW::setup_base_i& stack_builder)
if (!options.add_parse_and_check_necessary(new_options)) return nullptr;

poly->all = &all;
depthsbits_create(*poly.get());
sort_data_create(*poly.get());
depthsbits_create(*poly);
sort_data_create(*poly);

poly->batch_sz_double = !poly->batch_sz_double;

Expand All @@ -692,12 +693,14 @@ base_learner* stagewise_poly_setup(VW::setup_base_i& stack_builder)
poly->original_ec = nullptr;
poly->next_batch_sz = poly->batch_sz;

learner<stagewise_poly, example>& l = init_learner(poly, as_singleline(stack_builder.setup_base_learner()), learn,
predict, stack_builder.get_setupfn_name(stagewise_poly_setup));
auto* l = VW::LEARNER::make_reduction_learner(std::move(poly), as_singleline(stack_builder.setup_base_learner()),
learn, predict, stack_builder.get_setupfn_name(stagewise_poly_setup))
.set_label_type(label_type_t::simple)
.set_prediction_type(prediction_type_t::scalar)
.set_save_load(save_load)
.set_finish_example(finish_example)
.set_end_pass(end_pass)
.build();

l.set_save_load(save_load);
l.set_finish_example(finish_example);
l.set_end_pass(end_pass);

return make_base(l);
return make_base(*l);
}