-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Updated Initialization for Style Transfer #2988
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,8 @@ | |
|
|
||
| #include <toolkits/style_transfer/style_transfer_model_definition.hpp> | ||
|
|
||
| #include <random> | ||
|
|
||
| #include <ml/neural_net/weight_init.hpp> | ||
| #include <toolkits/coreml_export/mlmodel_include.hpp> | ||
|
|
||
|
|
@@ -20,14 +22,33 @@ using CoreML::Specification::NeuralNetworkLayer; | |
| using turi::neural_net::float_array_map; | ||
| using turi::neural_net::model_spec; | ||
| using turi::neural_net::scalar_weight_initializer; | ||
| using turi::neural_net::uniform_weight_initializer; | ||
| using turi::neural_net::weight_initializer; | ||
| using turi::neural_net::zero_weight_initializer; | ||
|
|
||
|
|
||
| using padding_type = model_spec::padding_type; | ||
|
|
||
| namespace { | ||
|
|
||
| constexpr float LOWER_BOUND = -0.07; | ||
| constexpr float UPPER_BOUND = 0.07; | ||
|
|
||
| // TODO: refactor code to be more readable with loops | ||
| void define_resnet(model_spec& nn_spec, size_t num_styles) { | ||
| void define_resnet(model_spec& nn_spec, size_t num_styles, bool initialize=false, int random_seed=0) { | ||
| std::mt19937 random_engine; | ||
| std::seed_seq seed_seq{random_seed}; | ||
| random_engine = std::mt19937(seed_seq); | ||
|
|
||
| weight_initializer initializer; | ||
|
|
||
| // This is to make sure that when the uniform initialization is not needed extra work is avoided | ||
| if (initialize) { | ||
| initializer = uniform_weight_initializer(LOWER_BOUND, UPPER_BOUND, &random_engine); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Are we exposing both the pre-trained weights and the plain uniform initialization? If we are exposing both, it would be good for users to have this documented somewhere. I don't think it is intuitive, so I think we should put it in our user guide. I also think that, perhaps in another PR, you should expose the random_seed parameter. It would be great if users could generate reproducible models.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I propose trying to have all the deep-learning toolkits controlled by random seeds for 6.2... |
||
| } else { | ||
| initializer = zero_weight_initializer(); | ||
| } | ||
|
|
||
| nn_spec.add_padding( | ||
| /* name */ "transformer_pad0", | ||
| /* input */ "image", | ||
|
|
@@ -46,7 +67,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { | |
| /* stride_height */ 1, | ||
| /* stride_width */ 1, | ||
| /* padding */ padding_type::VALID, | ||
| /* weight_init_fn */ zero_weight_initializer()); | ||
| /* weight_init_fn */ initializer); | ||
abhishekpratapa marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| nn_spec.add_inner_product( | ||
| /* name */ "transformer_encode_1_inst_gamma", | ||
|
|
@@ -102,7 +123,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { | |
| /* stride_height */ 2, | ||
| /* stride_width */ 2, | ||
| /* padding */ padding_type::VALID, | ||
| /* weight_init_fn */ zero_weight_initializer()); | ||
| /* weight_init_fn */ initializer); | ||
|
|
||
| nn_spec.add_inner_product( | ||
| /* name */ "transformer_encode_2_inst_gamma", | ||
|
|
@@ -158,7 +179,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { | |
| /* stride_height */ 2, | ||
| /* stride_width */ 2, | ||
| /* padding */ padding_type::VALID, | ||
| /* weight_init_fn */ zero_weight_initializer()); | ||
| /* weight_init_fn */ initializer); | ||
|
|
||
| nn_spec.add_inner_product( | ||
| /* name */ "transformer_encode_3_inst_gamma", | ||
|
|
@@ -214,7 +235,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { | |
| /* stride_height */ 1, | ||
| /* stride_width */ 1, | ||
| /* padding */ padding_type::VALID, | ||
| /* weight_init_fn */ zero_weight_initializer()); | ||
| /* weight_init_fn */ initializer); | ||
|
|
||
| nn_spec.add_inner_product( | ||
| /* name */ "transformer_residual_1_inst_1_gamma", | ||
|
|
@@ -270,7 +291,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { | |
| /* stride_height */ 1, | ||
| /* stride_width */ 1, | ||
| /* padding */ padding_type::VALID, | ||
| /* weight_init_fn */ zero_weight_initializer()); | ||
| /* weight_init_fn */ initializer); | ||
|
|
||
| nn_spec.add_inner_product( | ||
| /* name */ "transformer_residual_1_inst_2_gamma", | ||
|
|
@@ -327,7 +348,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { | |
| /* stride_height */ 1, | ||
| /* stride_width */ 1, | ||
| /* padding */ padding_type::VALID, | ||
| /* weight_init_fn */ zero_weight_initializer()); | ||
| /* weight_init_fn */ initializer); | ||
|
|
||
| nn_spec.add_inner_product( | ||
| /* name */ "transformer_residual_2_inst_1_gamma", | ||
|
|
@@ -383,7 +404,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { | |
| /* stride_height */ 1, | ||
| /* stride_width */ 1, | ||
| /* padding */ padding_type::VALID, | ||
| /* weight_init_fn */ zero_weight_initializer()); | ||
| /* weight_init_fn */ initializer); | ||
|
|
||
| nn_spec.add_inner_product( | ||
| /* name */ "transformer_residual_2_inst_2_gamma", | ||
|
|
@@ -440,7 +461,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { | |
| /* stride_height */ 1, | ||
| /* stride_width */ 1, | ||
| /* padding */ padding_type::VALID, | ||
| /* weight_init_fn */ zero_weight_initializer()); | ||
| /* weight_init_fn */ initializer); | ||
|
|
||
| nn_spec.add_inner_product( | ||
| /* name */ "transformer_residual_3_inst_1_gamma", | ||
|
|
@@ -496,7 +517,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { | |
| /* stride_height */ 1, | ||
| /* stride_width */ 1, | ||
| /* padding */ padding_type::VALID, | ||
| /* weight_init_fn */ zero_weight_initializer()); | ||
| /* weight_init_fn */ initializer); | ||
|
|
||
| nn_spec.add_inner_product( | ||
| /* name */ "transformer_residual_3_inst_2_gamma", | ||
|
|
@@ -553,7 +574,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { | |
| /* stride_height */ 1, | ||
| /* stride_width */ 1, | ||
| /* padding */ padding_type::VALID, | ||
| /* weight_init_fn */ zero_weight_initializer()); | ||
| /* weight_init_fn */ initializer); | ||
|
|
||
| nn_spec.add_inner_product( | ||
| /* name */ "transformer_residual_4_inst_1_gamma", | ||
|
|
@@ -609,7 +630,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { | |
| /* stride_height */ 1, | ||
| /* stride_width */ 1, | ||
| /* padding */ padding_type::VALID, | ||
| /* weight_init_fn */ zero_weight_initializer()); | ||
| /* weight_init_fn */ initializer); | ||
|
|
||
| nn_spec.add_inner_product( | ||
| /* name */ "transformer_residual_4_inst_2_gamma", | ||
|
|
@@ -666,7 +687,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { | |
| /* stride_height */ 1, | ||
| /* stride_width */ 1, | ||
| /* padding */ padding_type::VALID, | ||
| /* weight_init_fn */ zero_weight_initializer()); | ||
| /* weight_init_fn */ initializer); | ||
|
|
||
| nn_spec.add_inner_product( | ||
| /* name */ "transformer_residual_5_inst_1_gamma", | ||
|
|
@@ -722,7 +743,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { | |
| /* stride_height */ 1, | ||
| /* stride_width */ 1, | ||
| /* padding */ padding_type::VALID, | ||
| /* weight_init_fn */ zero_weight_initializer()); | ||
| /* weight_init_fn */ initializer); | ||
|
|
||
| nn_spec.add_inner_product( | ||
| /* name */ "transformer_residual_5_inst_2_gamma", | ||
|
|
@@ -785,7 +806,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { | |
| /* stride_height */ 1, | ||
| /* stride_width */ 1, | ||
| /* padding */ padding_type::VALID, | ||
| /* weight_init_fn */ zero_weight_initializer()); | ||
| /* weight_init_fn */ initializer); | ||
|
|
||
| nn_spec.add_inner_product( | ||
| /* name */ "transformer_decoding_1_inst_gamma", | ||
|
|
@@ -847,7 +868,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { | |
| /* stride_height */ 1, | ||
| /* stride_width */ 1, | ||
| /* padding */ padding_type::VALID, | ||
| /* weight_init_fn */ zero_weight_initializer()); | ||
| /* weight_init_fn */ initializer); | ||
|
|
||
| nn_spec.add_inner_product( | ||
| /* name */ "transformer_decoding_2_inst_gamma", | ||
|
|
@@ -903,7 +924,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) { | |
| /* stride_height */ 1, | ||
| /* stride_width */ 1, | ||
| /* padding */ padding_type::VALID, | ||
| /* weight_init_fn */ zero_weight_initializer()); | ||
| /* weight_init_fn */ initializer); | ||
|
|
||
| nn_spec.add_inner_product( | ||
| /* name */ "transformer_instancenorm5_gamma", | ||
|
|
@@ -1164,9 +1185,10 @@ std::unique_ptr<model_spec> init_resnet(const std::string& path) { | |
| return spec; | ||
| } | ||
|
|
||
| std::unique_ptr<neural_net::model_spec> init_resnet(size_t num_styles) { | ||
| std::unique_ptr<neural_net::model_spec> init_resnet(size_t num_styles, | ||
| int random_seed) { | ||
| std::unique_ptr<model_spec> nn_spec(new model_spec()); | ||
| define_resnet(*nn_spec, num_styles); | ||
| define_resnet(*nn_spec, num_styles, /* initialize */ true, random_seed); | ||
| return nn_spec; | ||
| } | ||
|
|
||
|
|
@@ -1190,4 +1212,4 @@ std::unique_ptr<model_spec> init_vgg_16(const std::string& path) { | |
| } | ||
|
|
||
| } // namespace style_transfer | ||
| } // namespace turi | ||
| } // namespace turi | ||
Uh oh!
There was an error while loading. Please reload this page.