Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/gh-pages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ concurrency:
cancel-in-progress: true

env:
OTP_VERSION: "25.0"
ELIXIR_VERSION: "1.14.0"
OTP_VERSION: "26.1.1"
ELIXIR_VERSION: "1.15.6"

jobs:
deploy:
Expand Down
14 changes: 7 additions & 7 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ jobs:
fail-fast: false
matrix:
include:
- otp: "27.1"
elixir: "1.19.0"
lint: true
- otp: "26.1.1"
elixir: "1.15.6"
- otp: "26.1.1"
elixir: "1.15.6"
lint: true
- otp: "25.3.2.6"
elixir: "1.14.5"
- otp: "25.3.2.6"
elixir: "1.14.5"
test_command_prepend: "USE_EXLA=true"
- otp: "25.3.2.6"
elixir: "1.14.5"
- otp: "26.1.1"
elixir: "1.15.6"
test_command_prepend: "USE_TORCHX=true"
steps:
- uses: actions/checkout@v3
Expand Down
6 changes: 3 additions & 3 deletions lib/axon/activations.ex
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ defmodule Axon.Activations do
#Nx.Tensor<
bf16[batch: 2][data: 3]
[
[7.781982421875e-4, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.3984375, 0.59765625, 0.796875]
]
>
Expand Down Expand Up @@ -249,7 +249,7 @@ defmodule Axon.Activations do
#Nx.Tensor<
bf16[batch: 2][data: 3]
[
[-7.781982421875e-4, -0.0, -0.0],
[-0.0, -0.0, -0.0],
[0.3984375, 1.1953125, 2.390625]
]
>
Expand Down Expand Up @@ -645,7 +645,7 @@ defmodule Axon.Activations do
#Nx.Tensor<
bf16[batch: 2][data: 3]
[
[-1.09375, -1.5078125, -1.6640625],
[-1.09375, -1.5, -1.65625],
[1.046875, 2.09375, 3.140625]
]
>
Expand Down
3 changes: 0 additions & 3 deletions lib/axon/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ defmodule Axon.Defn do
[fun.(vars)]
end

@impl true
def __stream__(_, _, _, _, _, _, _), do: raise("not implemented")

@impl true
def __compile__(_, _, _, _), do: raise("not implemented")

Expand Down
8 changes: 4 additions & 4 deletions lib/axon/loop.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1632,7 +1632,7 @@ defmodule Axon.Loop do
final_metrics_map = loop_state.metrics
loop_state = %{loop_state | metrics: zero_metrics}

{status, final_metrics_map, state} =
{status, final_metrics_map, %State{} = state} =
case fire_event(:started, handler_fns, loop_state, debug?) do
{:halt_epoch, state} ->
{:halted, final_metrics_map, state}
Expand All @@ -1647,7 +1647,7 @@ defmodule Axon.Loop do
Enum.reduce_while(
epoch_start..epoch_end//1,
{batch_fn, final_metrics_map, state},
fn epoch, {batch_fn, final_metrics_map, loop_state} ->
fn epoch, {batch_fn, final_metrics_map, %State{} = loop_state} ->
case fire_event(:epoch_started, handler_fns, loop_state, debug?) do
{:halt_epoch, state} ->
halt_epoch(handler_fns, batch_fn, final_metrics_map, state, debug?)
Expand Down Expand Up @@ -1691,7 +1691,7 @@ defmodule Axon.Loop do
{:halt_loop, state} ->
{:halt, {final_metrics_map, state}}

{:continue, state} ->
{:continue, %State{} = state} ->
{:cont,
{batch_fn, Map.put(final_metrics_map, epoch, state.metrics),
%State{
Expand Down Expand Up @@ -1924,7 +1924,7 @@ defmodule Axon.Loop do
# Halts an epoch during looping
defp halt_epoch(handler_fns, batch_fn, final_metrics_map, loop_state, debug?) do
case fire_event(:epoch_halted, handler_fns, loop_state, debug?) do
{:halt_epoch, %{epoch: epoch, metrics: metrics} = state} ->
{:halt_epoch, %State{epoch: epoch, metrics: metrics} = state} ->
final_metrics_map = Map.put(final_metrics_map, epoch, metrics)
{:cont, {batch_fn, final_metrics_map, %State{state | epoch: epoch + 1, iteration: 0}}}

Expand Down
10 changes: 6 additions & 4 deletions lib/axon/quantization/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,15 @@ defmodule Axon.Quantization.Layers do
end

deftransformp reshape_scales(scales, y) do
ones = List.to_tuple(List.duplicate(1, Nx.rank(y) - 1))
Nx.reshape(scales, Tuple.append(ones, :auto))
n = Nx.rank(y) - 1
ones = Tuple.duplicate(1, n)
Nx.reshape(scales, Tuple.insert_at(ones, n, :auto))
end

deftransformp reshape_output(output, x_shape) do
all_but_last = Tuple.delete_at(x_shape, tuple_size(x_shape) - 1)
new_shape = Tuple.append(all_but_last, :auto)
n = tuple_size(x_shape) - 1
all_but_last = Tuple.delete_at(x_shape, n)
new_shape = Tuple.insert_at(all_but_last, n, :auto)
Nx.reshape(output, new_shape)
end
end
19 changes: 11 additions & 8 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@ defmodule Axon.MixProject do
deps: deps(),
docs: docs(),
description: "Create and train neural networks in Elixir",
package: package(),
preferred_cli_env: [
docs: :docs,
"hex.publish": :docs
]
package: package()
]
end

def cli do
[
docs: :docs,
"hex.publish": :docs
]
end

Expand All @@ -35,9 +38,9 @@ defmodule Axon.MixProject do
# Run "mix help deps" to learn about dependencies.
defp deps do
[
{:nx, "~> 0.9", nx_opts()},
{:exla, "~> 0.9", [only: :test] ++ exla_opts()},
{:torchx, "~> 0.9", [only: :test] ++ torchx_opts()},
{:nx, "~> 0.10", nx_opts()},
{:exla, "~> 0.10", [only: :test] ++ exla_opts()},
{:torchx, "~> 0.10", [only: :test] ++ torchx_opts()},
{:ex_doc, "~> 0.23", only: :docs},
{:table_rex, "~> 3.1 or ~> 4.1", optional: true},
{:kino, "~> 0.7", optional: true},
Expand Down
13 changes: 7 additions & 6 deletions mix.lock
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
%{
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
"complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"},
"earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"},
"elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"},
"elixir_make": {:hex, :elixir_make, "0.9.0", "6484b3cd8c0cee58f09f05ecaf1a140a8c97670671a6a0e7ab4dc326c3109726", [:mix], [], "hexpm", "db23d4fd8b757462ad02f8aa73431a426fe6671c80b200d9710caf3d1dd0ffdb"},
"ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"},
"exla": {:hex, :exla, "0.9.0", "e048c7a3d33917c214774a7ea1a0c626eb9de01e3fb2423cf9e2b89ef6dada3a", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.9.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.8.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "cbd30b54992d0da01a5aaee361a3160fc29de05a9f6c3dbcbd1fa04b4aa72302"},
"exla": {:hex, :exla, "0.10.0", "93e7d75a774fbc06ce05b96de20c4b01bda413b315238cb3c727c09a05d2bc3a", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:fine, "~> 0.1.0", [hex: :fine, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.10.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.9.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "16fffdb64667d7f0a3bc683fdcd2792b143a9b345e4b1f1d5cd50330c63d8119"},
"fine": {:hex, :fine, "0.1.4", "b19a89c1476c7c57afb5f9314aed5960b5bc95d5277de4cb5ee8e1d1616ce379", [:mix], [], "hexpm", "be3324cc454a42d80951cf6023b9954e9ff27c6daa255483b3e8d608670303f5"},
"fss": {:hex, :fss, "0.1.1", "9db2344dbbb5d555ce442ac7c2f82dd975b605b50d169314a20f08ed21e08642", [:mix], [], "hexpm", "78ad5955c7919c3764065b21144913df7515d52e228c09427a004afe9c1a16b0"},
"kino": {:hex, :kino, "0.14.1", "c499afb1cd0be462feaf0a75c0631aa65aacc545b1c10f431b439b74f104be22", [:mix], [{:fss, "~> 0.1.0", [hex: :fss, repo: "hexpm", optional: false]}, {:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:plug, "~> 1.0", [hex: :plug, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "090aea1aaa267e42e5ac24ee6bc5ed515aecc0a9edb8619aa4ee839201e704aa"},
"kino_vega_lite": {:hex, :kino_vega_lite, "0.1.13", "03c00405987a2202e4b8014ee55eb7f5727691b3f13d76a3764f6eeccef45322", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.8", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "00c72bc270e7b9d3c339f726cdab0012fd3f2fc75e36c7548e0f250fe420fa10"},
Expand All @@ -12,12 +13,12 @@
"makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"},
"nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"},
"nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"},
"nx": {:hex, :nx, "0.9.0", "03a622a27d93eaaa2d24ff9b812d9f675cc04eb0340ca3dd065674f3642867d3", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "3810a5a90db0654b6e538430c0fb473a22bfc11b3d02ea7834db493cf3f56153"},
"nx": {:hex, :nx, "0.10.0", "128e4a094cb790f663e20e1334b127c1f2a4df54edfb8b13c22757ec33133b4f", [:mix], [{:complex, "~> 0.6", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "3db8892c124aeee091df0e6fbf8e5bf1b81f502eb0d4f5ba63e6378ebcae7da4"},
"polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"},
"table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"},
"table_rex": {:hex, :table_rex, "4.1.0", "fbaa8b1ce154c9772012bf445bfb86b587430fb96f3b12022d3f35ee4a68c918", [:mix], [], "hexpm", "95932701df195d43bc2d1c6531178fc8338aa8f38c80f098504d529c43bc2601"},
"telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"},
"torchx": {:hex, :torchx, "0.9.0", "936cbd32233f89d73700c39b7ef56f94b3f3541db03c90f8ddf6b3fe73260e28", [:mix], [{:nx, "~> 0.9.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "4e057d6b93fc91191957230b2c61c408861b888abdf6a900baf0db4125405505"},
"torchx": {:hex, :torchx, "0.10.2", "4b8529bfc4b0e641232497c99ef6d2508e652198840b212373333361352f0bae", [:mix], [{:nx, "~> 0.10.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "cad541c64df8ddcbf50d9b0f212961632361a03050c8e01493f0fc8d4fed96d9"},
"vega_lite": {:hex, :vega_lite, "0.1.9", "d7a288665f916181b68d0a3617f1b3611d16a4dcd5fafb51b847b71db1159d4c", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "c6a056e763162198e73ae6dfb46c09753bb0298474410fd085074e1cdcee7418"},
"xla": {:hex, :xla, "0.8.0", "fef314d085dd3ee16a0816c095239938f80769150e15db16dfaa435553d7cb16", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "739c61c8d93b97e12ba0369d10e76130224c208f1a76ad293e3581f056833e57"},
"xla": {:hex, :xla, "0.9.1", "cca0040ff94902764007a118871bfc667f1a0085d4a5074533a47d6b58bec61e", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "eb5e443ae5391b1953f253e051f2307bea183b59acee138053a9300779930daf"},
}
25 changes: 20 additions & 5 deletions test/axon/compiler_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4062,9 +4062,10 @@ defmodule CompilerTest do
}
} = params = init_fn.(input, ModelState.empty())

assert_equal(
assert_all_close(
predict_fn.(params, input),
Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/6, input, carry, Nx.tensor(0), k, h, b)
Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/6, input, carry, Nx.tensor(0), k, h, b),
atol: 1.0e-7
)
end

Expand Down Expand Up @@ -4192,7 +4193,11 @@ defmodule CompilerTest do
enc = {eik, ehk, eb}
dec = {dik, dhk, db}

assert_equal(predict_fn.(params, input), equiv_fn.(input, enc, dec))
assert_all_close(
predict_fn.(params, input),
equiv_fn.(input, enc, dec),
atol: 1.0e-7
)
end

test "initializes with use_bias false" do
Expand Down Expand Up @@ -5246,7 +5251,11 @@ defmodule CompilerTest do

input = random({1, 1})

assert_equal(predict_fn.(params, input), expected_predict_fn.(input, k1, b1, k2, b2))
assert_all_close(
predict_fn.(params, input),
expected_predict_fn.(input, k1, b1, k2, b2),
atol: 1.0e-7
)
end

test "predicts correctly with multiple dense, used twice" do
Expand Down Expand Up @@ -5290,7 +5299,11 @@ defmodule CompilerTest do

input = random({1, 1})

assert_equal(predict_fn.(params, input), expected_predict_fn.(input, k1, b1, k2, b2))
assert_all_close(
predict_fn.(params, input),
expected_predict_fn.(input, k1, b1, k2, b2),
atol: 1.0e-7
)
end

test "predicts correctly with multiple blocks in network" do
Expand Down Expand Up @@ -5703,6 +5716,8 @@ defmodule CompilerTest do
out =
ExUnit.CaptureIO.capture_io(fn ->
predict_fn.(model_state, input)
# Wait for async print operations to flush
Process.sleep(1000)
end)

assert out =~ "x:"
Expand Down
8 changes: 7 additions & 1 deletion test/axon/layers_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,13 @@ defmodule Axon.LayersTest do
bias = 0.0

assert_equal(
Axon.Layers.conv_transpose(inp, kernel, bias, padding: [{0, 1}, {1, 2}], channels: :first),
Axon.Layers.conv_transpose(
inp,
kernel,
bias,
padding: [{0, 1}, {1, 2}],
channels: :first
),
Nx.tensor([[[[0.0, 2.0, 3.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]]])
)
end
Expand Down
Loading