Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,24 @@
# Changelog

## v0.5.0 (2022-02-16)

### Enhancements

* Bump Nx dependency
* Update documentation to account for channels last default
* Improve error message in compilation/build errors for models
* Remove deprecated `transform`

### Deprecations

* Deprecate `Axon.Loop.handle/4`

## v0.4.1 (2022-01-21)

### Bug Fixes

* Fixed a shape mismatch when training with certain optimizers

## v0.4.0 (2022-01-19)

### Enhancements
Expand Down
3 changes: 3 additions & 0 deletions lib/axon/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@ defmodule Axon.Defn do

@impl true
def __compile__(_, _, _, _), do: raise("not implemented")

@impl true
def __partitions_options__(_), do: raise("not implemented")
end
12 changes: 2 additions & 10 deletions lib/axon/loop.ex
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,6 @@ defmodule Axon.Loop do
:iteration_completed, # On iteration complete
:epoch_completed, # On epoch complete
:epoch_halted, # On epoch halt, if early halted
:halted, # On loop halt, if early halted
:completed # On loop completion
]

You can attach event handlers to events using `Axon.Loop.handle_event/4`:
Expand Down Expand Up @@ -229,9 +227,7 @@ defmodule Axon.Loop do
:iteration_started,
:iteration_completed,
:epoch_completed,
:epoch_halted,
:halted,
:completed
:epoch_halted
]

@default_handlers %{
Expand Down Expand Up @@ -896,8 +892,6 @@ defmodule Axon.Loop do
:iteration_completed, # On iteration complete
:epoch_completed, # On epoch complete
:epoch_halted, # On epoch halt, if early halted
:halted, # On loop halt, if early halted
:completed # On loop completion
]

Generally, event handlers are side-effecting operations which provide some
Expand Down Expand Up @@ -1066,7 +1060,6 @@ defmodule Axon.Loop do

metrics =
Enum.reduce(metric_fns, evaluator, fn {k, {_, v}}, loop -> metric(loop, v, k) end)
|> log(fn _ -> "\n" end, event: :completed)
|> run(validation_data, model_state)
|> Access.get(0)
|> Map.new(fn {k, v} ->
Expand Down Expand Up @@ -1733,8 +1726,7 @@ defmodule Axon.Loop do
end
end

{_, state} = fire_event(status, handler_fns, state, debug?)
state = %State{state | metrics: final_metrics}
state = %State{state | metrics: final_metrics, status: status}

output_transform.(state)
end
Expand Down
4 changes: 4 additions & 0 deletions lib/axon/loop/state.ex
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,14 @@ defmodule Axon.Loop.State do

`event_counts` is a metadata field which stores information about the number
of times each event has been fired. This is useful when creating custom filters.

`status` refers to the loop state status after the loop has executed. You can
use this to determine if the loop ran to completion or if it was halted early.
"""
@enforce_keys [:step_state]
defstruct [
:step_state,
:status,
handler_metadata: %{},
epoch: 0,
max_epoch: 1,
Expand Down
14 changes: 7 additions & 7 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defmodule Axon.MixProject do
use Mix.Project

@source_url "https://github.com/elixir-nx/axon"
@version "0.4.1"
@version "0.5.0"

def project do
[
Expand Down Expand Up @@ -35,9 +35,9 @@ defmodule Axon.MixProject do
# Run "mix help deps" to learn about dependencies.
defp deps do
[
{:exla, "~> 0.4.0", [only: :test] ++ exla_opts()},
{:torchx, "~> 0.4.0", [only: :test] ++ torchx_opts()},
{:nx, "~> 0.4.0", nx_opts()},
{:exla, "~> 0.5.0", [only: :test] ++ exla_opts()},
{:torchx, "~> 0.5.0", [only: :test] ++ torchx_opts()},
{:nx, "~> 0.5.0", nx_opts()},
{:ex_doc, "~> 0.23", only: :docs},
{:table_rex, "~> 3.1.1", optional: true},
{:kino, "~> 0.7", optional: true},
Expand All @@ -57,23 +57,23 @@ defmodule Axon.MixProject do
if path = System.get_env("AXON_NX_PATH") do
[path: path, override: true]
else
[github: "elixir-nx/nx", sparse: "nx", override: true]
[]
end
end

defp exla_opts do
if path = System.get_env("AXON_EXLA_PATH") do
[path: path]
else
[github: "elixir-nx/nx", sparse: "exla", override: true]
[]
end
end

defp torchx_opts do
if path = System.get_env("AXON_TORCHX_PATH") do
[path: path]
else
[github: "elixir-nx/nx", sparse: "torchx", override: true]
[]
end
end

Expand Down
8 changes: 4 additions & 4 deletions mix.lock
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
%{
"castore": {:hex, :castore, "0.1.22", "4127549e411bedd012ca3a308dede574f43819fe9394254ca55ab4895abfa1a2", [:mix], [], "hexpm", "c17576df47eb5aa1ee40cc4134316a99f5cad3e215d5c77b8dd3cfef12a22cac"},
"cc_precompiler": {:hex, :cc_precompiler, "0.1.5", "ac3ef86f31ab579b856192a948e956cc3e4bb5006e303c4ab4b24958108e218a", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "ee5b2e56eb03798231a3d322579fff509139a534ef54205d04c188e18cab1f57"},
"complex": {:hex, :complex, "0.4.3", "84db4aad241099a8785446ac6eacf498bf3a60634a0e45c7745d875714ddbf98", [:mix], [], "hexpm", "2ceda96ebddcc22697974f1a2666d4cc5dfdd34f8cd8c4f9dced037bcb41eeb5"},
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
"dll_loader_helper": {:hex, :dll_loader_helper, "0.1.10", "ba85d66f82c1748513dbaee71aa9d0593bb9a65dba246b980753c4d683b0a07b", [:make, :mix], [{:castore, ">= 0.0.0", [hex: :castore, repo: "hexpm", optional: false]}, {:cc_precompiler, "~> 0.1", [hex: :cc_precompiler, repo: "hexpm", optional: false]}], "hexpm", "c0d02a2d8cd0085252f7551a343f89060bb7beb3f303d991e46a7370ed257485"},
"earmark_parser": {:hex, :earmark_parser, "1.4.30", "0b938aa5b9bafd455056440cdaa2a79197ca5e693830b4a982beada840513c5f", [:mix], [], "hexpm", "3b5385c2d36b0473d0b206927b841343d25adb14f95f0110062506b300cd5a1b"},
"elixir_make": {:hex, :elixir_make, "0.7.3", "c37fdae1b52d2cc51069713a58c2314877c1ad40800a57efb213f77b078a460d", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "24ada3e3996adbed1fa024ca14995ef2ba3d0d17b678b0f3f2b1f66e6ce2b274"},
"ex_doc": {:hex, :ex_doc, "0.29.1", "b1c652fa5f92ee9cf15c75271168027f92039b3877094290a75abcaac82a9f77", [:mix], [{:earmark_parser, "~> 1.4.19", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "b7745fa6374a36daf484e2a2012274950e084815b936b1319aeebcf7809574f6"},
"exla": {:git, "https://github.com/elixir-nx/nx.git", "952dd193f8be7041a1ee835bbe58753baa5460cc", [sparse: "exla"]},
"exla": {:hex, :exla, "0.5.0", "a002cb70e59c26d4ec78a256489e4026c428ff4917f25d266e6a86c58636dc7f", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.5.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.4.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "9219366cb0ea18c421349b8e0f130d85e83d7404df8054e5af6e18a47540c886"},
"kino": {:hex, :kino, "0.8.1", "da3b2cba121b7542146cffdb8af055fa0129395fa67aead9e7e3df93aed1f107", [:mix], [{:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "da45dd141db30db18973de0e3398bda3ab8cb0b5da58d6a0debbe5b864aba295"},
"kino_vega_lite": {:hex, :kino_vega_lite, "0.1.7", "c93fdfe6e35c4c5a4f8afd51a89786b2187e5a7da4595b13ea02a4329d9f0976", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.4", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "59ee442f0532266749d15dc9af4e2875bec61ccfa1b07636bc396ee63dfde8e7"},
"makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"},
"makeup_elixir": {:hex, :makeup_elixir, "0.16.0", "f8c570a0d33f8039513fbccaf7108c5d750f47d8defd44088371191b76492b0b", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "28b2cbdc13960a46ae9a8858c4bebdec3c9a6d7b4b9e7f4ed1502f8159f338e7"},
"makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"},
"nimble_parsec": {:hex, :nimble_parsec, "1.2.3", "244836e6e3f1200c7f30cb56733fd808744eca61fd182f731eac4af635cc6d0b", [:mix], [], "hexpm", "c8d789e39b9131acf7b99291e93dae60ab48ef14a7ee9d58c6964f59efb570b0"},
"nx": {:git, "https://github.com/elixir-nx/nx.git", "952dd193f8be7041a1ee835bbe58753baa5460cc", [sparse: "nx"]},
"nx": {:hex, :nx, "0.5.0", "c5e62e82606ff372d986e72cce505c98421bb4305ce9cc8e439fe6cc1966c6ad", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "b29c246318181c3ebfcf0f230a0d33783ac4c92dfa34ca3aa5b9b38ae58c187e"},
"table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"},
"table_rex": {:hex, :table_rex, "3.1.1", "0c67164d1714b5e806d5067c1e96ff098ba7ae79413cc075973e17c38a587caa", [:mix], [], "hexpm", "678a23aba4d670419c23c17790f9dcd635a4a89022040df7d5d772cb21012490"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
"torchx": {:git, "https://github.com/elixir-nx/nx.git", "952dd193f8be7041a1ee835bbe58753baa5460cc", [sparse: "torchx"]},
"torchx": {:hex, :torchx, "0.5.0", "d787ea5a62f299a93c03a7a9f1d0d903dd854797e8fc27bbbee984d8e3e6acf1", [:make, :mix], [{:dll_loader_helper, "~> 0.1.0", [hex: :dll_loader_helper, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.5.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "832205d22259011930231e5203cc1b929136a3ad1b160e1f4690d35dfb11ddbd"},
"vega_lite": {:hex, :vega_lite, "0.1.6", "145ab4908bc890b02cef3526e890e9b899528eaa7aa9d6fa642b52a8a2c682c6", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "078c0d8cd9a8eca4ae8f9527c45c01d69cefb6b2235fd5179a227ac2f031d7ac"},
"xla": {:hex, :xla, "0.4.3", "cf6201aaa44d990298996156a83a16b9a87c5fbb257758dbf4c3e83c5e1c4b96", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "caae164b56dcaec6fbcabcd7dea14303afde07623b0cfa4a3cd2576b923105f5"},
}
69 changes: 9 additions & 60 deletions test/axon/loop_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ defmodule Axon.LoopTest do
Axon.input("input", shape: {nil, 1})
|> Axon.dense(1)
|> Loop.trainer(:binary_cross_entropy, :sgd, log: 0)
|> Loop.handle(
|> Loop.handle_event(
:epoch_completed,
fn %State{step_state: pstate} = state ->
{
Expand All @@ -376,14 +376,6 @@ defmodule Axon.LoopTest do
}
end
)
|> Loop.handle(
:completed,
fn %State{step_state: %{counter: counter}} = state ->
assert 4 = counter

{:continue, state}
end
)
|> Loop.run(
[{Nx.tensor([[1.0]]), Nx.tensor([[1.0]])}],
%{},
Expand All @@ -396,7 +388,7 @@ defmodule Axon.LoopTest do
Axon.input("input", shape: {nil, 1})
|> Axon.dense(1)
|> Loop.trainer(:binary_cross_entropy, :sgd, log: 0)
|> Loop.handle(
|> Loop.handle_event(
:epoch_completed,
fn %State{step_state: pstate} = state ->
{
Expand All @@ -416,14 +408,6 @@ defmodule Axon.LoopTest do
}
end
)
|> Loop.handle(
:completed,
fn %State{step_state: %{counter: counter}} = state ->
assert {{4}, 4} = counter

{:continue, state}
end
)
|> Loop.run(
[{Nx.tensor([[1.0]]), Nx.tensor([[1.0]])}],
%{},
Expand Down Expand Up @@ -477,7 +461,7 @@ defmodule Axon.LoopTest do
end

def send_handler(loop, event) do
Axon.Loop.handle(loop, event, fn state ->
Axon.Loop.handle_event(loop, event, fn state ->
send(self(), event)
{:continue, state}
end)
Expand Down Expand Up @@ -540,15 +524,6 @@ defmodule Axon.LoopTest do
refute_received :iteration_completed
end

test "fires correctly on :completed" do
ExUnit.CaptureIO.capture_io(fn ->
run_dummy_loop!(:completed, 5, 10)
end)

assert_received :completed
refute_received :completed
end

test "fires correctly on :epoch_halted" do
model = Axon.input("foo")

Expand All @@ -562,7 +537,7 @@ defmodule Axon.LoopTest do
ExUnit.CaptureIO.capture_io(fn ->
model
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
|> Axon.Loop.handle(:iteration_completed, fn state ->
|> Axon.Loop.handle_event(:iteration_completed, fn state ->
{:halt_epoch, state}
end)
|> send_handler(:epoch_halted)
Expand All @@ -576,30 +551,6 @@ defmodule Axon.LoopTest do
refute_received :epoch_halted
end

test "fires correctly on :halted" do
model = Axon.input("foo")

data =
Stream.repeatedly(fn ->
xs = Nx.tensor([[Enum.random(0..10)]])
ys = Nx.greater(xs, 5)
{xs, ys}
end)

ExUnit.CaptureIO.capture_io(fn ->
model
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
|> Axon.Loop.handle(:iteration_completed, fn state ->
{:halt_loop, state}
end)
|> send_handler(:halted)
|> Axon.Loop.run(data, %{}, epochs: 5, iterations: 10)
end)

assert_received :halted
refute_received :halted
end

test "events fire in order" do
model = Axon.input("foo")

Expand All @@ -618,7 +569,6 @@ defmodule Axon.LoopTest do
|> send_handler(:iteration_started)
|> send_handler(:iteration_completed)
|> send_handler(:epoch_completed)
|> send_handler(:completed)
|> Axon.Loop.run(data, %{}, epochs: 1, iterations: 1)
end)

Expand All @@ -627,7 +577,6 @@ defmodule Axon.LoopTest do
assert_received :iteration_started
assert_received :iteration_completed
assert_received :epoch_completed
assert_received :completed

refute_received _
end
Expand All @@ -651,7 +600,7 @@ defmodule Axon.LoopTest do
end

def send_handler(loop, event, filter) do
Axon.Loop.handle(
Axon.Loop.handle_event(
loop,
event,
fn state ->
Expand Down Expand Up @@ -863,7 +812,7 @@ defmodule Axon.LoopTest do
model
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
|> Axon.Loop.from_state(state1)
|> Axon.Loop.handle(:epoch_completed, fn %{epoch: epoch} = state ->
|> Axon.Loop.handle_event(:epoch_completed, fn %{epoch: epoch} = state ->
assert epoch >= 3
{:continue, state}
end)
Expand All @@ -888,7 +837,7 @@ defmodule Axon.LoopTest do
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, Enum.take(data, 5))
|> Axon.Loop.handle(
|> Axon.Loop.handle_event(
:epoch_completed,
fn %{metrics: metrics} = state ->
assert Map.has_key?(metrics, "validation_accuracy")
Expand Down Expand Up @@ -918,7 +867,7 @@ defmodule Axon.LoopTest do
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, Enum.take(data, 5))
|> Axon.Loop.early_stop("validation_accuracy", mode: :max)
|> Axon.Loop.handle(
|> Axon.Loop.handle_event(
:epoch_completed,
fn %{handler_metadata: meta} = state ->
assert %{early_stop: %{"validation_accuracy" => _, :since_last_improvement => _}} =
Expand Down Expand Up @@ -1006,7 +955,7 @@ defmodule Axon.LoopTest do
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, Enum.take(data, 5))
|> Axon.Loop.reduce_lr_on_plateau("validation_accuracy", mode: :max)
|> Axon.Loop.handle(
|> Axon.Loop.handle_event(
:epoch_completed,
fn %{handler_metadata: meta} = state ->
assert %{reduce_lr: %{"validation_accuracy" => _, :since_last_improvement => _}} =
Expand Down