@@ -13,12 +13,17 @@ defmodule MNISTGAN do
1313 alias Axon.Loop.State
1414 import Nx.Defn
1515
16+ @ batch_size 32
17+ @ image_channels 1
18+ @ image_side_pixels 28
19+ @ channel_value_max 255
20+
1621 defp transform_images ( { bin , type , shape } ) do
1722 bin
1823 |> Nx . from_binary ( type )
19- |> Nx . reshape ( { elem ( shape , 0 ) , 1 , 28 , 28 } )
20- |> Nx . divide ( 255.0 )
21- |> Nx . to_batched ( 32 )
24+ |> Nx . reshape ( { elem ( shape , 0 ) , @ image_channels , @ image_side_pixels , @ image_side_pixels } )
25+ |> Nx . divide ( @ channel_value_max )
26+ |> Nx . to_batched ( @ batch_size )
2227 end
2328
2429 defp build_generator ( z_dim ) do
@@ -32,9 +37,9 @@ defmodule MNISTGAN do
3237 |> Axon . dense ( 1024 )
3338 |> Axon . leaky_relu ( alpha: 0.9 )
3439 |> Axon . batch_norm ( )
35- |> Axon . dense ( 784 )
40+ |> Axon . dense ( @ image_side_pixels ** 2 )
3641 |> Axon . tanh ( )
37- |> Axon . reshape ( { :batch , 28 , 28 , 1 } )
42+ |> Axon . reshape ( { :batch , @ image_side_pixels , @ image_side_pixels , @ image_channels } )
3843 end
3944
4045 defp build_discriminator ( input_shape ) do
@@ -80,9 +85,9 @@ defmodule MNISTGAN do
8085 g_params = state [ :generator ] [ :model_state ]
8186
8287 # Update D
83- fake_labels = Nx . iota ( { 32 , 2 } , axis: 1 )
88+ fake_labels = Nx . iota ( { @ batch_size , 2 } , axis: 1 )
8489 real_labels = Nx . reverse ( fake_labels )
85- { noise , random_next_key } = Nx.Random . normal ( state [ :random_key ] , shape: { 32 , 100 } )
90+ { noise , random_next_key } = Nx.Random . normal ( state [ :random_key ] , shape: { @ batch_size , 100 } )
8691
8792 { d_loss , d_grads } =
8893 value_and_grad ( d_params , fn params ->
@@ -162,7 +167,7 @@ defmodule MNISTGAN do
162167 preds = Axon . predict ( model , pstate [ :generator ] [ :model_state ] , noise )
163168
164169 preds
165- |> Nx . reshape ( { batch_size , 28 , 28 } )
170+ |> Nx . reshape ( { batch_size , @ image_side_pixels , @ image_side_pixels } )
166171 |> Nx . to_heatmap ( )
167172 |> IO . inspect ( )
168173
@@ -174,7 +179,7 @@ defmodule MNISTGAN do
174179 train_images = transform_images ( images )
175180
176181 generator = build_generator ( 100 )
177- discriminator = build_discriminator ( { nil , 28 , 28 , 1 } )
182+ discriminator = build_discriminator ( { nil , @ image_side_pixels , @ image_side_pixels , @ image_channels } )
178183
179184 discriminator
180185 |> train_loop ( generator )
0 commit comments