This repository was archived by the owner on Dec 21, 2023. It is now read-only.

Merged
13 changes: 13 additions & 0 deletions src/ml/neural_net/weight_init.cpp
@@ -35,6 +35,19 @@ void xavier_weight_initializer::operator()(float* first_weight,
}
}

uniform_weight_initializer::uniform_weight_initializer(
float lower_bound, float upper_bound, std::mt19937* random_engine)
: dist_(std::uniform_real_distribution<float>(lower_bound, upper_bound)),
random_engine_(*random_engine)
{}

void uniform_weight_initializer::operator()(float* first_weight,
float* last_weight) {
for (float* w = first_weight; w != last_weight; ++w) {
*w = dist_(random_engine_);
}
}

scalar_weight_initializer::scalar_weight_initializer(float scalar)
: scalar_(scalar) {}

26 changes: 26 additions & 0 deletions src/ml/neural_net/weight_init.hpp
@@ -48,6 +48,32 @@ class xavier_weight_initializer {
std::mt19937& random_engine_;
};


class uniform_weight_initializer {
public:

/**
* Creates a weight initializer that performs uniform initialization.
*
* \param lower_bound The lower bound of the uniform distribution to be sampled.
* \param upper_bound The upper bound of the uniform distribution to be sampled.
* \param random_engine The random number generator to use, which must remain
* valid for the lifetime of this instance.
*/
uniform_weight_initializer(float lower_bound, float upper_bound,
std::mt19937* random_engine);

/**
* Initializes each value in [first_weight, last_weight) uniformly at random in the range [lower_bound, upper_bound].
*/
void operator()(float* first_weight, float* last_weight);

private:

std::uniform_real_distribution<float> dist_;
std::mt19937& random_engine_;
};
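For reference, a minimal usage sketch of the new initializer (not part of this diff; the seed, buffer size, and include path are assumptions based on the repo's src layout):

#include <random>
#include <vector>

#include <ml/neural_net/weight_init.hpp>

int main() {
  // The engine must outlive the initializer, which keeps a reference to it.
  std::mt19937 engine(42);
  turi::neural_net::uniform_weight_initializer init(-0.07f, 0.07f, &engine);

  std::vector<float> weights(64);  // e.g. one small weight tensor
  init(weights.data(), weights.data() + weights.size());  // fills [first, last)
  return 0;
}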

struct scalar_weight_initializer {
/**
* Creates a weight initializer that initializes all of the weights to a
* single scalar value.
*/
@@ -247,6 +247,7 @@ def create(
options["num_styles"] = len(style_dataset)
options["resnet_mlmodel_path"] = pretrained_resnet_model.get_model_path("coreml")
options["vgg_mlmodel_path"] = pretrained_vgg16_model.get_model_path("coreml")
options["pretrained_weights"] = params["pretrained_weights"]

model.train(style_dataset[style_feature], content_dataset[content_feature], options)
return StyleTransfer(model_proxy=model, name=name)
19 changes: 17 additions & 2 deletions src/toolkits/style_transfer/style_transfer.cpp
@@ -686,6 +686,13 @@ void style_transfer::init_train(gl_sarray style, gl_sarray content,
}
size_t num_styles = num_styles_iter->second;

auto pretrained_weights_iter = opts.find("pretrained_weights");
bool pretrained_weights = false;
if (pretrained_weights_iter != opts.end()) {
pretrained_weights = pretrained_weights_iter->second;
// Erase only when the key is present; erasing opts.end() is undefined behavior.
opts.erase(pretrained_weights_iter);
}

init_options(opts);

if (read_state<flexible_type>("random_seed") == FLEX_UNDEFINED) {
@@ -694,9 +701,11 @@
add_or_update_state({{"random_seed", random_seed}});
}

int random_seed = read_state<int>("random_seed");

m_training_data_iterator =
create_iterator(content, style, /* repeat */ true,
/* training */ true, static_cast<int>(num_styles));
/* training */ true, random_seed);

m_training_compute_context = create_compute_context();
if (m_training_compute_context == nullptr) {
@@ -709,7 +718,13 @@
{"styles", style_sframe_with_index(style)},
{"num_content_images", content.size()}});

m_resnet_spec = init_resnet(resnet_mlmodel_path, num_styles);
// TODO: change to include random seed.
if (pretrained_weights) {
m_resnet_spec = init_resnet(resnet_mlmodel_path, num_styles);
} else {
m_resnet_spec = init_resnet(num_styles, random_seed);
}

m_vgg_spec = init_vgg_16(vgg_mlmodel_path);

float_array_map weight_params = m_resnet_spec->export_params_view();
62 changes: 42 additions & 20 deletions src/toolkits/style_transfer/style_transfer_model_definition.cpp
@@ -7,6 +7,8 @@

#include <toolkits/style_transfer/style_transfer_model_definition.hpp>

#include <random>

#include <ml/neural_net/weight_init.hpp>
#include <toolkits/coreml_export/mlmodel_include.hpp>

@@ -20,14 +22,33 @@ using CoreML::Specification::NeuralNetworkLayer;
using turi::neural_net::float_array_map;
using turi::neural_net::model_spec;
using turi::neural_net::scalar_weight_initializer;
using turi::neural_net::uniform_weight_initializer;
using turi::neural_net::weight_initializer;
using turi::neural_net::zero_weight_initializer;


using padding_type = model_spec::padding_type;

namespace {

constexpr float LOWER_BOUND = -0.07f;
constexpr float UPPER_BOUND = 0.07f;

// TODO: refactor code to be more readable with loops
void define_resnet(model_spec& nn_spec, size_t num_styles) {
void define_resnet(model_spec& nn_spec, size_t num_styles, bool initialize=false, int random_seed=0) {
std::seed_seq seed_seq{random_seed};
std::mt19937 random_engine(seed_seq);

weight_initializer initializer;

// Skip the extra work of uniform sampling when it is not requested.
if (initialize) {
initializer = uniform_weight_initializer(LOWER_BOUND, UPPER_BOUND, &random_engine);
} else {
initializer = zero_weight_initializer();
}

Review comment (Collaborator):

I believe &random_engine will become a dangling pointer once random_engine goes out of scope after this line. I suspect you will need to move random_engine outside the scope of this if statement to obtain well-defined behavior.

Review comment (Contributor, @shreyajain17, Feb 8, 2020):

@nickjong

What is the relationship between this PR and #2874?

This fixes the waviness users were experiencing. The pre-trained model lets users get faster models, but it is more likely to introduce artifacts into the stylization. The uniform initialization lets the user train a higher-quality model, but with a higher number of iterations. It resolves (7.) of the issue.

Are we exposing both the pre-trained weights and the uniform initialization? If we are exposing both, it would be good for users to actually have this literature somewhere. I don't think it is intuitive, and I think we should put it in our user guide.

I also think that in another PR you should expose the random_seed parameter. It would be great if users could generate reproducible models.

Review comment (Collaborator):

I propose trying to have all the deep-learning toolkits controlled by random seeds for 6.2...
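To illustrate the reviewer's concern, a minimal hypothetical sketch (not code from this PR) of how a functor that holds a reference to a block-scoped engine dangles once that block exits:

#include <functional>
#include <random>

std::function<float()> make_sampler() {
  std::function<float()> sampler;
  {
    std::mt19937 engine(42);  // destroyed when this block exits
    std::uniform_real_distribution<float> dist(-0.07f, 0.07f);
    // The lambda captures the block-scoped engine by reference.
    sampler = [&engine, dist]() mutable { return dist(engine); };
  }
  return sampler;  // invoking the result reads a destroyed engine: undefined behavior
}

In the diff as written, random_engine lives at function scope and the initializer is only invoked inside define_resnet, so the reference stays valid for its entire use.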

nn_spec.add_padding(
/* name */ "transformer_pad0",
/* input */ "image",
Expand All @@ -46,7 +67,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_encode_1_inst_gamma",
@@ -102,7 +123,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 2,
/* stride_width */ 2,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_encode_2_inst_gamma",
@@ -158,7 +179,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 2,
/* stride_width */ 2,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_encode_3_inst_gamma",
@@ -214,7 +235,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_1_inst_1_gamma",
@@ -270,7 +291,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_1_inst_2_gamma",
@@ -327,7 +348,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_2_inst_1_gamma",
@@ -383,7 +404,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_2_inst_2_gamma",
@@ -440,7 +461,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_3_inst_1_gamma",
@@ -496,7 +517,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_3_inst_2_gamma",
@@ -553,7 +574,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_4_inst_1_gamma",
@@ -609,7 +630,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_4_inst_2_gamma",
@@ -666,7 +687,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_5_inst_1_gamma",
@@ -722,7 +743,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_5_inst_2_gamma",
@@ -785,7 +806,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_decoding_1_inst_gamma",
@@ -847,7 +868,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_decoding_2_inst_gamma",
@@ -903,7 +924,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_instancenorm5_gamma",
@@ -1164,9 +1185,10 @@ std::unique_ptr<model_spec> init_resnet(const std::string& path) {
return spec;
}

std::unique_ptr<neural_net::model_spec> init_resnet(size_t num_styles) {
std::unique_ptr<neural_net::model_spec> init_resnet(size_t num_styles,
int random_seed) {
std::unique_ptr<model_spec> nn_spec(new model_spec());
define_resnet(*nn_spec, num_styles);
define_resnet(*nn_spec, num_styles, /* initialize */ true, random_seed);
return nn_spec;
}

@@ -1190,4 +1212,4 @@ std::unique_ptr<model_spec> init_vgg_16(const std::string& path) {
}

} // namespace style_transfer
} // namespace turi
@@ -17,7 +17,8 @@ namespace turi {
namespace style_transfer {

std::unique_ptr<neural_net::model_spec> init_resnet(const std::string& path);
std::unique_ptr<neural_net::model_spec> init_resnet(size_t num_styles);
std::unique_ptr<neural_net::model_spec> init_resnet(size_t num_styles,
int random_seed=0);
std::unique_ptr<neural_net::model_spec> init_resnet(const std::string& path,
size_t num_styles);
std::unique_ptr<neural_net::model_spec> init_vgg_16();
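As a reading aid (not part of the diff), a sketch of how a caller might choose between the two init_resnet overloads, mirroring the branch added in style_transfer::init_train; make_resnet_spec is a hypothetical helper:

#include <memory>
#include <string>

#include <toolkits/style_transfer/style_transfer_model_definition.hpp>

std::unique_ptr<turi::neural_net::model_spec> make_resnet_spec(
    bool pretrained_weights, const std::string& resnet_mlmodel_path,
    size_t num_styles, int random_seed) {
  using namespace turi::style_transfer;
  if (pretrained_weights) {
    // Start from the pre-trained weights shipped as a Core ML model on disk.
    return init_resnet(resnet_mlmodel_path, num_styles);
  }
  // Train from scratch: weights drawn uniformly from [-0.07, 0.07],
  // seeded so runs are reproducible.
  return init_resnet(num_styles, random_seed);
}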