Skip to content
This repository was archived by the owner on Dec 21, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def create(
options["num_styles"] = len(style_dataset)
options["resnet_mlmodel_path"] = pretrained_resnet_model.get_model_path("coreml")
options["vgg_mlmodel_path"] = pretrained_vgg16_model.get_model_path("coreml")
options["pretrained_weights"] = params["pretrained_weights"]

model.train(style_dataset[style_feature], content_dataset[content_feature], options)
return StyleTransfer(model_proxy=model, name=name)
Expand Down
18 changes: 16 additions & 2 deletions src/toolkits/style_transfer/style_transfer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,12 @@ void style_transfer::init_train(gl_sarray style, gl_sarray content,
}
size_t num_styles = num_styles_iter->second;

auto pretrained_weights_iter = opts.find("pretrained_weights");
bool pretrained_weights = false;
if (pretrained_weights_iter != opts.end()) {
pretrained_weights = pretrained_weights_iter->second;
}

init_options(opts);

if (read_state<flexible_type>("random_seed") == FLEX_UNDEFINED) {
Expand All @@ -694,9 +700,11 @@ void style_transfer::init_train(gl_sarray style, gl_sarray content,
add_or_update_state({{"random_seed", random_seed}});
}

int random_seed = read_state<int>("random_seed");

m_training_data_iterator =
create_iterator(content, style, /* repeat */ true,
/* training */ true, static_cast<int>(num_styles));
/* training */ true, random_seed);

m_training_compute_context = create_compute_context();
if (m_training_compute_context == nullptr) {
Expand All @@ -709,7 +717,13 @@ void style_transfer::init_train(gl_sarray style, gl_sarray content,
{"styles", style_sframe_with_index(style)},
{"num_content_images", content.size()}});

m_resnet_spec = init_resnet(resnet_mlmodel_path, num_styles);
// TODO: change to include random seed (only the non-pretrained init_resnet call below passes it).
if (pretrained_weights) {
m_resnet_spec = init_resnet(resnet_mlmodel_path, num_styles);
} else {
m_resnet_spec = init_resnet(num_styles, random_seed);
}

m_vgg_spec = init_vgg_16(vgg_mlmodel_path);

float_array_map weight_params = m_resnet_spec->export_params_view();
Expand Down
51 changes: 32 additions & 19 deletions src/toolkits/style_transfer/style_transfer_model_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

#include <toolkits/style_transfer/style_transfer_model_definition.hpp>

#include <random>

#include <ml/neural_net/weight_init.hpp>
#include <toolkits/coreml_export/mlmodel_include.hpp>

Expand All @@ -20,14 +22,24 @@ using CoreML::Specification::NeuralNetworkLayer;
using turi::neural_net::float_array_map;
using turi::neural_net::model_spec;
using turi::neural_net::scalar_weight_initializer;
using turi::neural_net::xavier_weight_initializer;
using turi::neural_net::zero_weight_initializer;

using padding_type = model_spec::padding_type;

namespace {

// Bounds that define_resnet passes as the first two arguments of
// xavier_weight_initializer when it builds the shared weight initializer.
// NOTE(review): confirm that xavier_weight_initializer interprets these as
// value bounds rather than fan-in/fan-out sizes.
constexpr float LOWER_BOUND = -0.7;
constexpr float UPPER_BOUND = 0.7;

// TODO: refactor the repeated layer definitions into loops to make this more readable
void define_resnet(model_spec& nn_spec, size_t num_styles) {
void define_resnet(model_spec& nn_spec, size_t num_styles, int random_seed=0) {
std::mt19937 random_engine;
std::seed_seq seed_seq{random_seed};
random_engine = std::mt19937(seed_seq);

auto initializer = xavier_weight_initializer(LOWER_BOUND, UPPER_BOUND, &random_engine);

nn_spec.add_padding(
/* name */ "transformer_pad0",
/* input */ "image",
Expand All @@ -46,7 +58,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_encode_1_inst_gamma",
Expand Down Expand Up @@ -102,7 +114,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 2,
/* stride_width */ 2,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_encode_2_inst_gamma",
Expand Down Expand Up @@ -158,7 +170,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 2,
/* stride_width */ 2,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_encode_3_inst_gamma",
Expand Down Expand Up @@ -214,7 +226,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_1_inst_1_gamma",
Expand Down Expand Up @@ -270,7 +282,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_1_inst_2_gamma",
Expand Down Expand Up @@ -327,7 +339,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_2_inst_1_gamma",
Expand Down Expand Up @@ -383,7 +395,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_2_inst_2_gamma",
Expand Down Expand Up @@ -440,7 +452,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_3_inst_1_gamma",
Expand Down Expand Up @@ -496,7 +508,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_3_inst_2_gamma",
Expand Down Expand Up @@ -553,7 +565,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_4_inst_1_gamma",
Expand Down Expand Up @@ -609,7 +621,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_4_inst_2_gamma",
Expand Down Expand Up @@ -666,7 +678,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_5_inst_1_gamma",
Expand Down Expand Up @@ -722,7 +734,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_5_inst_2_gamma",
Expand Down Expand Up @@ -785,7 +797,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_decoding_1_inst_gamma",
Expand Down Expand Up @@ -847,7 +859,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_decoding_2_inst_gamma",
Expand Down Expand Up @@ -903,7 +915,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_instancenorm5_gamma",
Expand Down Expand Up @@ -1164,9 +1176,10 @@ std::unique_ptr<model_spec> init_resnet(const std::string& path) {
return spec;
}

// Allocates a new model_spec, populates it through define_resnet, and returns
// it to the caller as a unique_ptr.
std::unique_ptr<neural_net::model_spec> init_resnet(size_t num_styles) {
std::unique_ptr<neural_net::model_spec> init_resnet(size_t num_styles,
int random_seed) {
std::unique_ptr<model_spec> nn_spec(new model_spec());
define_resnet(*nn_spec, num_styles);
define_resnet(*nn_spec, num_styles, random_seed);
return nn_spec;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ namespace turi {
namespace style_transfer {

// Overloads that build the resnet model_spec. In init_train, the
// (path, num_styles) overload is called when pretrained weights are requested,
// and the (num_styles, random_seed) overload is called otherwise.
std::unique_ptr<neural_net::model_spec> init_resnet(const std::string& path);
std::unique_ptr<neural_net::model_spec> init_resnet(size_t num_styles);
std::unique_ptr<neural_net::model_spec> init_resnet(size_t num_styles,
int random_seed=0);
std::unique_ptr<neural_net::model_spec> init_resnet(const std::string& path,
size_t num_styles);
std::unique_ptr<neural_net::model_spec> init_vgg_16();
Expand Down