Skip to content

Commit cd75977

Browse files
authored
refactor(examples): attempt to improve clarity of where some hard-coded numbers coming from or mean (#512)
1 parent 8eb6f9b commit cd75977

File tree

8 files changed

+71
-39
lines changed

8 files changed

+71
-39
lines changed

examples/basics/multi_input_example.exs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ Mix.install([
77
defmodule XOR do
88
require Axon
99

10+
@batch_size 32
11+
1012
defp build_model(input_shape1, input_shape2) do
1113
inp1 = Axon.input("x1", shape: input_shape1)
1214
inp2 = Axon.input("x2", shape: input_shape2)
@@ -18,8 +20,8 @@ defmodule XOR do
1820
end
1921

2022
defp batch do
21-
x1 = Nx.tensor(for _ <- 1..32, do: [Enum.random(0..1)])
22-
x2 = Nx.tensor(for _ <- 1..32, do: [Enum.random(0..1)])
23+
x1 = Nx.tensor(for _ <- 1..@batch_size, do: [Enum.random(0..1)])
24+
x2 = Nx.tensor(for _ <- 1..@batch_size, do: [Enum.random(0..1)])
2325
y = Nx.logical_xor(x1, x2)
2426
{%{"x1" => x1, "x2" => x2}, y}
2527
end

examples/basics/multi_output_example.exs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ Mix.install([
77
defmodule Power do
88
require Axon
99

10+
@batch_size 32
11+
1012
defp build_model do
1113
fc =
1214
Axon.input("input", shape: {nil, 1})
@@ -34,8 +36,7 @@ defmodule Power do
3436
Stream.unfold(
3537
Nx.Random.key(:erlang.system_time()),
3638
fn key ->
37-
# Batch size of 32
38-
{x, next_key} = Nx.Random.uniform(key, -10, 10, shape: {32, 1}, type: {:f, 32})
39+
{x, next_key} = Nx.Random.uniform(key, -10, 10, shape: {@batch_size, 1}, type: {:f, 32})
3940
{{x, {Nx.pow(x, 2), Nx.pow(x, 3)}}, next_key}
4041
end
4142
)

examples/generative/fashionmnist_autoencoder.exs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,15 @@ Mix.install([
88
defmodule FashionMNIST do
99
require Axon
1010

11+
@batch_size 32
12+
@image_channels 1
13+
@image_side_pixels 28
14+
@channel_value_max 255
15+
1116
defmodule Autoencoder do
17+
@image_channels 1
18+
@image_side_pixels 28
19+
1220
defp encoder(x, latent_dim) do
1321
x
1422
|> Axon.flatten()
@@ -17,8 +25,8 @@ defmodule FashionMNIST do
1725

1826
defp decoder(x) do
1927
x
20-
|> Axon.dense(784, activation: :sigmoid)
21-
|> Axon.reshape({:batch, 1, 28, 28})
28+
|> Axon.dense(@image_side_pixels**2, activation: :sigmoid)
29+
|> Axon.reshape({:batch, @image_channels, @image_side_pixels, @image_side_pixels})
2230
end
2331

2432
def build_model(input_shape, latent_dim) do
@@ -31,9 +39,9 @@ defmodule FashionMNIST do
3139
defp transform_images({bin, type, shape}) do
3240
bin
3341
|> Nx.from_binary(type)
34-
|> Nx.reshape({elem(shape, 0), 1, 28, 28})
35-
|> Nx.divide(255.0)
36-
|> Nx.to_batched(32)
42+
|> Nx.reshape({elem(shape, 0), @image_channels, @image_side_pixels, @image_side_pixels})
43+
|> Nx.divide(@channel_value_max)
44+
|> Nx.to_batched(@batch_size)
3745
end
3846

3947
defp train_model(model, train_images, epochs) do
@@ -48,15 +56,15 @@ defmodule FashionMNIST do
4856

4957
train_images = transform_images(images)
5058

51-
model = Autoencoder.build_model({nil, 1, 28, 28}, 64) |> IO.inspect()
59+
model = Autoencoder.build_model({nil, @image_channels, @image_side_pixels, @image_side_pixels}, 64) |> IO.inspect()
5260

5361
model_state = train_model(model, train_images, 5)
5462

5563
sample_image =
5664
train_images
5765
|> Enum.fetch!(0)
5866
|> Nx.slice_along_axis(0, 1)
59-
|> Nx.reshape({1, 1, 28, 28})
67+
|> Nx.reshape({1, @image_channels, @image_side_pixels, @image_side_pixels})
6068

6169
sample_image |> Nx.to_heatmap() |> IO.inspect()
6270

examples/generative/mnist_gan.exs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,17 @@ defmodule MNISTGAN do
1313
alias Axon.Loop.State
1414
import Nx.Defn
1515

16+
@batch_size 32
17+
@image_channels 1
18+
@image_side_pixels 28
19+
@channel_value_max 255
20+
1621
defp transform_images({bin, type, shape}) do
1722
bin
1823
|> Nx.from_binary(type)
19-
|> Nx.reshape({elem(shape, 0), 1, 28, 28})
20-
|> Nx.divide(255.0)
21-
|> Nx.to_batched(32)
24+
|> Nx.reshape({elem(shape, 0), @image_channels, @image_side_pixels, @image_side_pixels})
25+
|> Nx.divide(@channel_value_max)
26+
|> Nx.to_batched(@batch_size)
2227
end
2328

2429
defp build_generator(z_dim) do
@@ -32,9 +37,9 @@ defmodule MNISTGAN do
3237
|> Axon.dense(1024)
3338
|> Axon.leaky_relu(alpha: 0.9)
3439
|> Axon.batch_norm()
35-
|> Axon.dense(784)
40+
|> Axon.dense(@image_side_pixels**2)
3641
|> Axon.tanh()
37-
|> Axon.reshape({:batch, 28, 28, 1})
42+
|> Axon.reshape({:batch, @image_side_pixels, @image_side_pixels, @image_channels})
3843
end
3944

4045
defp build_discriminator(input_shape) do
@@ -80,9 +85,9 @@ defmodule MNISTGAN do
8085
g_params = state[:generator][:model_state]
8186

8287
# Update D
83-
fake_labels = Nx.iota({32, 2}, axis: 1)
88+
fake_labels = Nx.iota({@batch_size, 2}, axis: 1)
8489
real_labels = Nx.reverse(fake_labels)
85-
{noise, random_next_key} = Nx.Random.normal(state[:random_key], shape: {32, 100})
90+
{noise, random_next_key} = Nx.Random.normal(state[:random_key], shape: {@batch_size, 100})
8691

8792
{d_loss, d_grads} =
8893
value_and_grad(d_params, fn params ->
@@ -162,7 +167,7 @@ defmodule MNISTGAN do
162167
preds = Axon.predict(model, pstate[:generator][:model_state], noise)
163168

164169
preds
165-
|> Nx.reshape({batch_size, 28, 28})
170+
|> Nx.reshape({batch_size, @image_side_pixels, @image_side_pixels})
166171
|> Nx.to_heatmap()
167172
|> IO.inspect()
168173

@@ -174,7 +179,7 @@ defmodule MNISTGAN do
174179
train_images = transform_images(images)
175180

176181
generator = build_generator(100)
177-
discriminator = build_discriminator({nil, 28, 28, 1})
182+
discriminator = build_discriminator({nil, @image_side_pixels, @image_side_pixels, @image_channels})
178183

179184
discriminator
180185
|> train_loop(generator)

examples/vision/cifar10.exs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,27 @@ Mix.install([
88
defmodule Cifar do
99
require Axon
1010

11+
@batch_size 32
12+
@channel_value_max 255
13+
@label_values Enum.to_list(0..9)
14+
1115
defp transform_images({bin, type, shape}) do
1216
bin
1317
|> Nx.from_binary(type)
1418
|> Nx.reshape(shape, names: [:count, :channels, :width, :height])
1519
# Move channels to last position to match what conv layer expects
1620
|> Nx.transpose(axes: [:count, :width, :height, :channels])
17-
|> Nx.divide(255.0)
18-
|> Nx.to_batched(32)
21+
|> Nx.divide(@channel_value_max)
22+
|> Nx.to_batched(@batch_size)
1923
|> Enum.split(1500)
2024
end
2125

2226
defp transform_labels({bin, type, _}) do
2327
bin
2428
|> Nx.from_binary(type)
2529
|> Nx.new_axis(-1)
26-
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
27-
|> Nx.to_batched(32)
30+
|> Nx.equal(Nx.tensor(@label_values))
31+
|> Nx.to_batched(@batch_size)
2832
|> Enum.split(1500)
2933
end
3034

@@ -39,7 +43,7 @@ defmodule Cifar do
3943
|> Axon.flatten()
4044
|> Axon.dense(64, activation: :relu)
4145
|> Axon.dropout(rate: 0.5)
42-
|> Axon.dense(10, activation: :softmax)
46+
|> Axon.dense(length(@label_values), activation: :softmax)
4347
end
4448

4549
defp train_model(model, train_images, train_labels, epochs) do

examples/vision/cnn_image_denoising.exs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ defmodule MnistDenoising do
1212
@noise_factor 0.4
1313
@batch_size 32
1414
@epochs 25
15+
@image_channels 1
16+
@image_side_pixels 28
17+
@channel_value_max 255
1518

1619
def run do
1720
{images, _} = Scidata.MNIST.download()
@@ -26,7 +29,7 @@ defmodule MnistDenoising do
2629
noisy_train_images |> Enum.take(1) |> hd() |> display_image()
2730

2831
# Train with noisy images as input and train images as targets
29-
model = build_model({nil, 1, 28, 28})
32+
model = build_model({nil, @image_channels, @image_side_pixels, @image_side_pixels})
3033

3134
model_state =
3235
model
@@ -46,8 +49,8 @@ defmodule MnistDenoising do
4649
defp transform_images({bin, type, shape}) do
4750
bin
4851
|> Nx.from_binary(type)
49-
|> Nx.reshape({elem(shape, 0), 28, 28, 1})
50-
|> Nx.divide(255.0)
52+
|> Nx.reshape({elem(shape, 0), @image_side_pixels, @image_side_pixels, @image_channels})
53+
|> Nx.divide(@channel_value_max)
5154
|> Nx.to_batched_list(@batch_size)
5255
# Test split
5356
|> Enum.split(1750)
@@ -63,7 +66,7 @@ defmodule MnistDenoising do
6366
defp display_image(images) do
6467
images
6568
|> Nx.slice_along_axis(0, 1)
66-
|> Nx.reshape({28, 28, 1})
69+
|> Nx.reshape({@image_side_pixels, @image_side_pixels, @image_channels})
6770
|> Nx.to_heatmap()
6871
|> IO.inspect()
6972
end

examples/vision/horses_or_humans.exs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@ defmodule HorsesOrHumans do
1515
# or you can use Req to download and extract the zip file and iterate
1616
# over the resulting data
1717
@directories "examples/vision/{horses,humans}/*"
18+
@batch_size 32
19+
@image_channels 4
20+
@image_side_pixels 300
21+
@channel_value_max 255
1822

1923
def data() do
2024
Path.wildcard(@directories)
21-
|> Stream.chunk_every(32, 32, :discard)
25+
|> Stream.chunk_every(@batch_size, @batch_size, :discard)
2226
|> Task.async_stream(fn batch ->
2327
{inp, labels} = batch |> Enum.map(&parse_png/1) |> Enum.unzip()
2428
{Nx.stack(inp), Nx.stack(labels)}
@@ -29,7 +33,7 @@ defmodule HorsesOrHumans do
2933

3034
defnp augment(inp) do
3135
# Normalize
32-
inp = inp / 255.0
36+
inp = inp / @channel_value_max
3337

3438
# For now just a random flip
3539
if Nx.random_uniform({}) > 0.5 do
@@ -78,7 +82,7 @@ defmodule HorsesOrHumans do
7882
end
7983

8084
def run() do
81-
model = build_model({nil, 300, 300, 4}) |> IO.inspect()
85+
model = build_model({nil, @image_side_pixels, @image_side_pixels, @image_channels}) |> IO.inspect()
8286
optimizer = Polaris.Optimizers.adam(learning_rate: 1.0e-4)
8387
centralized_optimizer = Polaris.Updates.compose(Polaris.Updates.centralize(), optimizer)
8488

examples/vision/mnist.exs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,17 @@ Mix.install([
99
defmodule Mnist do
1010
require Axon
1111

12+
@batch_size 32
13+
@image_side_pixels 28
14+
@channel_value_max 255
15+
@label_values Enum.to_list(0..9)
16+
1217
defp transform_images({bin, type, shape}) do
1318
bin
1419
|> Nx.from_binary(type)
15-
|> Nx.reshape({elem(shape, 0), 784})
16-
|> Nx.divide(255.0)
17-
|> Nx.to_batched(32)
20+
|> Nx.reshape({elem(shape, 0), @image_side_pixels**2})
21+
|> Nx.divide(@channel_value_max)
22+
|> Nx.to_batched(@batch_size)
1823
# Test split
1924
|> Enum.split(1750)
2025
end
@@ -23,8 +28,8 @@ defmodule Mnist do
2328
bin
2429
|> Nx.from_binary(type)
2530
|> Nx.new_axis(-1)
26-
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
27-
|> Nx.to_batched(32)
31+
|> Nx.equal(Nx.tensor(@label_values))
32+
|> Nx.to_batched(@batch_size)
2833
# Test split
2934
|> Enum.split(1750)
3035
end
@@ -33,7 +38,7 @@ defmodule Mnist do
3338
Axon.input("input", shape: input_shape)
3439
|> Axon.dense(128, activation: :relu)
3540
|> Axon.dropout()
36-
|> Axon.dense(10, activation: :softmax)
41+
|> Axon.dense(length(@label_values), activation: :softmax)
3742
end
3843

3944
defp train_model(model, train_images, train_labels, epochs) do
@@ -56,7 +61,7 @@ defmodule Mnist do
5661
{train_images, test_images} = transform_images(images)
5762
{train_labels, test_labels} = transform_labels(labels)
5863

59-
model = build_model({nil, 784}) |> IO.inspect()
64+
model = build_model({nil, @image_side_pixels**2}) |> IO.inspect()
6065

6166
IO.write("\n\n Training Model \n\n")
6267

0 commit comments

Comments
 (0)