Skip to content

Commit 9c82a22

Browse files
committed
Fixes the confusing source of the maximum channel value used for input normalization
1 parent c4771cb commit 9c82a22

File tree

6 files changed

+15
-10
lines changed

6 files changed

+15
-10
lines changed

examples/generative/fashionmnist_autoencoder.exs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ defmodule FashionMNIST do
1111
@batch_size 32
1212
@image_channels 1
1313
@image_side_pixels 28
14+
@channel_value_max 255
1415

1516
defmodule Autoencoder do
1617
@image_channels 1
@@ -39,7 +40,7 @@ defmodule FashionMNIST do
3940
bin
4041
|> Nx.from_binary(type)
4142
|> Nx.reshape({elem(shape, 0), @image_channels, @image_side_pixels, @image_side_pixels})
42-
|> Nx.divide(Nx.Constants.max(type))
43+
|> Nx.divide(@channel_value_max)
4344
|> Nx.to_batched(@batch_size)
4445
end
4546

examples/generative/mnist_gan.exs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@ defmodule MNISTGAN do
1616
@batch_size 32
1717
@image_channels 1
1818
@image_side_pixels 28
19+
@channel_value_max 255
1920

2021
defp transform_images({bin, type, shape}) do
2122
bin
2223
|> Nx.from_binary(type)
2324
|> Nx.reshape({elem(shape, 0), @image_channels, @image_side_pixels, @image_side_pixels})
24-
|> Nx.divide(Nx.Constants.max(type))
25+
|> Nx.divide(@channel_value_max)
2526
|> Nx.to_batched(@batch_size)
2627
end
2728

examples/vision/cifar10.exs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@ defmodule Cifar do
1111
@batch_size 32
1212
@image_channels 3
1313
@image_side_pixels 32
14+
@channel_value_max 255
1415
@label_values Enum.to_list(0..9)
1516

1617
defp transform_images({bin, type, shape}) do
1718
bin
1819
|> Nx.from_binary(type)
1920
|> Nx.reshape({elem(shape, 0), @image_side_pixels, @image_side_pixels, @image_channels})
20-
|> Nx.divide(Nx.Constants.max(type))
21+
|> Nx.divide(@channel_value_max)
2122
|> Nx.to_batched(@batch_size)
2223
|> Enum.split(1500)
2324
end

examples/vision/cnn_image_denoising.exs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ defmodule MnistDenoising do
1414
@epochs 25
1515
@image_channels 1
1616
@image_side_pixels 28
17+
@channel_value_max 255
1718

1819
def run do
1920
{images, _} = Scidata.MNIST.download()
@@ -49,7 +50,7 @@ defmodule MnistDenoising do
4950
bin
5051
|> Nx.from_binary(type)
5152
|> Nx.reshape({elem(shape, 0), @image_side_pixels, @image_side_pixels, @image_channels})
52-
|> Nx.divide(Nx.Constants.max(type))
53+
|> Nx.divide(@channel_value_max)
5354
|> Nx.to_batched_list(@batch_size)
5455
# Test split
5556
|> Enum.split(1750)

examples/vision/horses_or_humans.exs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ defmodule HorsesOrHumans do
1818
@batch_size 32
1919
@image_channels 4
2020
@image_side_pixels 300
21-
@input_type {:u, 8}
21+
@channel_value_max 255
2222

2323
def data() do
2424
Path.wildcard(@directories)
@@ -31,9 +31,9 @@ defmodule HorsesOrHumans do
3131
|> Stream.cycle()
3232
end
3333

34-
defnp augment(inp) do
34+
defnp augment(%Nx.Tensor{type: type} = inp) do
3535
# Normalize
36-
inp = inp / Nx.Constants.max(@input_type)
36+
inp = inp / @channel_value_max
3737

3838
# For now just a random flip
3939
if Nx.random_uniform({}) > 0.5 do
@@ -46,8 +46,8 @@ defmodule HorsesOrHumans do
4646
defp parse_png(filename) do
4747
class =
4848
if String.contains?(filename, "horses"),
49-
do: Nx.tensor([1, 0], type: @input_type),
50-
else: Nx.tensor([0, 1], type: @input_type)
49+
do: Nx.tensor([1, 0], type: {:u, 8}),
50+
else: Nx.tensor([0, 1], type: {:u, 8})
5151

5252
{:ok, img} = StbImage.read_file(filename)
5353

examples/vision/mnist.exs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@ defmodule Mnist do
1111

1212
@batch_size 32
1313
@image_side_pixels 28
14+
@channel_value_max 255
1415
@label_values Enum.to_list(0..9)
1516

1617
defp transform_images({bin, type, shape}) do
1718
bin
1819
|> Nx.from_binary(type)
1920
|> Nx.reshape({elem(shape, 0), @image_side_pixels**2})
20-
|> Nx.divide(Nx.Constants.max(type))
21+
|> Nx.divide(@channel_value_max)
2122
|> Nx.to_batched(@batch_size)
2223
# Test split
2324
|> Enum.split(1750)

0 commit comments

Comments
 (0)