diff --git a/src/ml/neural_net/weight_init.cpp b/src/ml/neural_net/weight_init.cpp index 8a108fb4f6..1511a20a68 100644 --- a/src/ml/neural_net/weight_init.cpp +++ b/src/ml/neural_net/weight_init.cpp @@ -35,6 +35,19 @@ void xavier_weight_initializer::operator()(float* first_weight, } } +uniform_weight_initializer::uniform_weight_initializer( + float lower_bound, float upper_bound, std::mt19937* random_engine) + : dist_(std::uniform_real_distribution(lower_bound, upper_bound)), + random_engine_(*random_engine) +{} + +void uniform_weight_initializer::operator()(float* first_weight, + float* last_weight) { + for (float* w = first_weight; w != last_weight; ++w) { + *w = dist_(random_engine_); + } +} + scalar_weight_initializer::scalar_weight_initializer(float scalar) : scalar_(scalar) {} diff --git a/src/ml/neural_net/weight_init.hpp b/src/ml/neural_net/weight_init.hpp index 94aeff9fff..d85ba66e18 100644 --- a/src/ml/neural_net/weight_init.hpp +++ b/src/ml/neural_net/weight_init.hpp @@ -48,6 +48,32 @@ class xavier_weight_initializer { std::mt19937& random_engine_; }; + +class uniform_weight_initializer { + public: + + /** + * Creates a weight initializer that performs Uniform initialization + * + * \param lower_bound The lower bound of the uniform distribution to be sampled + * \param upper_bound The upper bound of the uniform distribution to be sampled + * \param random_engine The random number generator to use, which must remain + * valid for the lifetime of this instance. + */ + uniform_weight_initializer(float lower_bound, float upper_bound, + std::mt19937* random_engine); + + /** + * Initializes each value in uniformly at random in the range [-lower_bound, upper_bound] + */ + void operator()(float* first_weight, float* last_weight); + +private: + + std::uniform_real_distribution dist_; + std::mt19937& random_engine_; +}; + struct scalar_weight_initializer { /** * Creates a weight initializer that initializes all of the weights to a diff --git a/src/python/turicreate/toolkits/style_transfer/style_transfer.py b/src/python/turicreate/toolkits/style_transfer/style_transfer.py index 3350beb392..1452e61d6c 100644 --- a/src/python/turicreate/toolkits/style_transfer/style_transfer.py +++ b/src/python/turicreate/toolkits/style_transfer/style_transfer.py @@ -247,6 +247,7 @@ def create( options["num_styles"] = len(style_dataset) options["resnet_mlmodel_path"] = pretrained_resnet_model.get_model_path("coreml") options["vgg_mlmodel_path"] = pretrained_vgg16_model.get_model_path("coreml") + options["pretrained_weights"] = params["pretrained_weights"] model.train(style_dataset[style_feature], content_dataset[content_feature], options) return StyleTransfer(model_proxy=model, name=name) diff --git a/src/toolkits/style_transfer/style_transfer.cpp b/src/toolkits/style_transfer/style_transfer.cpp index c7d6b1b13d..d0cda303c0 100644 --- a/src/toolkits/style_transfer/style_transfer.cpp +++ b/src/toolkits/style_transfer/style_transfer.cpp @@ -686,6 +686,13 @@ void style_transfer::init_train(gl_sarray style, gl_sarray content, } size_t num_styles = num_styles_iter->second; + auto pretrained_weights_iter = opts.find("pretrained_weights"); + bool pretrained_weights = false; + if (pretrained_weights_iter != opts.end()) { + pretrained_weights = pretrained_weights_iter->second; + } + opts.erase(pretrained_weights_iter); + init_options(opts); if (read_state("random_seed") == FLEX_UNDEFINED) { @@ -694,9 +701,11 @@ void style_transfer::init_train(gl_sarray style, gl_sarray content, add_or_update_state({{"random_seed", random_seed}}); } + int random_seed = read_state("random_seed"); + m_training_data_iterator = create_iterator(content, style, /* repeat */ true, - /* training */ true, static_cast(num_styles)); + /* training */ true, random_seed); m_training_compute_context = create_compute_context(); if (m_training_compute_context == nullptr) { @@ -709,7 +718,13 @@ void style_transfer::init_train(gl_sarray style, gl_sarray content, {"styles", style_sframe_with_index(style)}, {"num_content_images", content.size()}}); - m_resnet_spec = init_resnet(resnet_mlmodel_path, num_styles); + // TODO: change to include random seed. + if (pretrained_weights) { + m_resnet_spec = init_resnet(resnet_mlmodel_path, num_styles); + } else { + m_resnet_spec = init_resnet(num_styles, random_seed); + } + m_vgg_spec = init_vgg_16(vgg_mlmodel_path); float_array_map weight_params = m_resnet_spec->export_params_view(); diff --git a/src/toolkits/style_transfer/style_transfer_model_definition.cpp b/src/toolkits/style_transfer/style_transfer_model_definition.cpp index 1725b6922d..05740374b5 100644 --- a/src/toolkits/style_transfer/style_transfer_model_definition.cpp +++ b/src/toolkits/style_transfer/style_transfer_model_definition.cpp @@ -7,6 +7,8 @@ #include +#include + #include #include @@ -20,14 +22,33 @@ using CoreML::Specification::NeuralNetworkLayer; using turi::neural_net::float_array_map; using turi::neural_net::model_spec; using turi::neural_net::scalar_weight_initializer; +using turi::neural_net::uniform_weight_initializer; +using turi::neural_net::weight_initializer; using turi::neural_net::zero_weight_initializer; + using padding_type = model_spec::padding_type; namespace { +constexpr float LOWER_BOUND = -0.07; +constexpr float UPPER_BOUND = 0.07; + // TODO: refactor code to be more readable with loops -void define_resnet(model_spec& nn_spec, size_t num_styles) { +void define_resnet(model_spec& nn_spec, size_t num_styles, bool initialize=false, int random_seed=0) { + std::mt19937 random_engine; + std::seed_seq seed_seq{random_seed}; + random_engine = std::mt19937(seed_seq); + + weight_initializer initializer; + + // This is to make sure that when the uniform initialization is not needed extra work is avoided + if (initialize) { + initializer = uniform_weight_initializer(LOWER_BOUND, UPPER_BOUND, &random_engine); + } else { + initializer = zero_weight_initializer(); + } + nn_spec.add_padding( /* name */ "transformer_pad0", /* input */ "image", @@ -46,7 +67,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { /* stride_height */ 1, /* stride_width */ 1, /* padding */ padding_type::VALID, - /* weight_init_fn */ zero_weight_initializer()); + /* weight_init_fn */ initializer); nn_spec.add_inner_product( /* name */ "transformer_encode_1_inst_gamma", @@ -102,7 +123,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { /* stride_height */ 2, /* stride_width */ 2, /* padding */ padding_type::VALID, - /* weight_init_fn */ zero_weight_initializer()); + /* weight_init_fn */ initializer); nn_spec.add_inner_product( /* name */ "transformer_encode_2_inst_gamma", @@ -158,7 +179,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { /* stride_height */ 2, /* stride_width */ 2, /* padding */ padding_type::VALID, - /* weight_init_fn */ zero_weight_initializer()); + /* weight_init_fn */ initializer); nn_spec.add_inner_product( /* name */ "transformer_encode_3_inst_gamma", @@ -214,7 +235,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { /* stride_height */ 1, /* stride_width */ 1, /* padding */ padding_type::VALID, - /* weight_init_fn */ zero_weight_initializer()); + /* weight_init_fn */ initializer); nn_spec.add_inner_product( /* name */ "transformer_residual_1_inst_1_gamma", @@ -270,7 +291,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { /* stride_height */ 1, /* stride_width */ 1, /* padding */ padding_type::VALID, - /* weight_init_fn */ zero_weight_initializer()); + /* weight_init_fn */ initializer); nn_spec.add_inner_product( /* name */ "transformer_residual_1_inst_2_gamma", @@ -327,7 +348,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { /* stride_height */ 1, /* stride_width */ 1, /* padding */ padding_type::VALID, - /* weight_init_fn */ zero_weight_initializer()); + /* weight_init_fn */ initializer); nn_spec.add_inner_product( /* name */ "transformer_residual_2_inst_1_gamma", @@ -383,7 +404,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { /* stride_height */ 1, /* stride_width */ 1, /* padding */ padding_type::VALID, - /* weight_init_fn */ zero_weight_initializer()); + /* weight_init_fn */ initializer); nn_spec.add_inner_product( /* name */ "transformer_residual_2_inst_2_gamma", @@ -440,7 +461,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { /* stride_height */ 1, /* stride_width */ 1, /* padding */ padding_type::VALID, - /* weight_init_fn */ zero_weight_initializer()); + /* weight_init_fn */ initializer); nn_spec.add_inner_product( /* name */ "transformer_residual_3_inst_1_gamma", @@ -496,7 +517,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { /* stride_height */ 1, /* stride_width */ 1, /* padding */ padding_type::VALID, - /* weight_init_fn */ zero_weight_initializer()); + /* weight_init_fn */ initializer); nn_spec.add_inner_product( /* name */ "transformer_residual_3_inst_2_gamma", @@ -553,7 +574,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { /* stride_height */ 1, /* stride_width */ 1, /* padding */ padding_type::VALID, - /* weight_init_fn */ zero_weight_initializer()); + /* weight_init_fn */ initializer); nn_spec.add_inner_product( /* name */ "transformer_residual_4_inst_1_gamma", @@ -609,7 +630,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { /* stride_height */ 1, /* stride_width */ 1, /* padding */ padding_type::VALID, - /* weight_init_fn */ zero_weight_initializer()); + /* weight_init_fn */ initializer); nn_spec.add_inner_product( /* name */ "transformer_residual_4_inst_2_gamma", @@ -666,7 +687,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { /* stride_height */ 1, /* stride_width */ 1, /* padding */ padding_type::VALID, - /* weight_init_fn */ zero_weight_initializer()); + /* weight_init_fn */ initializer); nn_spec.add_inner_product( /* name */ "transformer_residual_5_inst_1_gamma", @@ -722,7 +743,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { /* stride_height */ 1, /* stride_width */ 1, /* padding */ padding_type::VALID, - /* weight_init_fn */ zero_weight_initializer()); + /* weight_init_fn */ initializer); nn_spec.add_inner_product( /* name */ "transformer_residual_5_inst_2_gamma", @@ -785,7 +806,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { /* stride_height */ 1, /* stride_width */ 1, /* padding */ padding_type::VALID, - /* weight_init_fn */ zero_weight_initializer()); + /* weight_init_fn */ initializer); nn_spec.add_inner_product( /* name */ "transformer_decoding_1_inst_gamma", @@ -847,7 +868,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { /* stride_height */ 1, /* stride_width */ 1, /* padding */ padding_type::VALID, - /* weight_init_fn */ zero_weight_initializer()); + /* weight_init_fn */ initializer); nn_spec.add_inner_product( /* name */ "transformer_decoding_2_inst_gamma", @@ -903,7 +924,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { /* stride_height */ 1, /* stride_width */ 1, /* padding */ padding_type::VALID, - /* weight_init_fn */ zero_weight_initializer()); + /* weight_init_fn */ initializer); nn_spec.add_inner_product( /* name */ "transformer_instancenorm5_gamma", @@ -1164,9 +1185,10 @@ std::unique_ptr init_resnet(const std::string& path) { return spec; } -std::unique_ptr init_resnet(size_t num_styles) { +std::unique_ptr init_resnet(size_t num_styles, + int random_seed) { std::unique_ptr nn_spec(new model_spec()); - define_resnet(*nn_spec, num_styles); + define_resnet(*nn_spec, num_styles, /* initialize */ true, random_seed); return nn_spec; } @@ -1190,4 +1212,4 @@ std::unique_ptr init_vgg_16(const std::string& path) { } } // namespace style_transfer -} // namespace turi \ No newline at end of file +} // namespace turi diff --git a/src/toolkits/style_transfer/style_transfer_model_definition.hpp b/src/toolkits/style_transfer/style_transfer_model_definition.hpp index d3a5ccd417..267b29b5ba 100644 --- a/src/toolkits/style_transfer/style_transfer_model_definition.hpp +++ b/src/toolkits/style_transfer/style_transfer_model_definition.hpp @@ -17,7 +17,8 @@ namespace turi { namespace style_transfer { std::unique_ptr init_resnet(const std::string& path); -std::unique_ptr init_resnet(size_t num_styles); +std::unique_ptr init_resnet(size_t num_styles, + int random_seed=0); std::unique_ptr init_resnet(const std::string& path, size_t num_styles); std::unique_ptr init_vgg_16();