Skip to content

Commit f67d45e

Browse files
authored
Use global event counts for correct filters in loops (#425)
1 parent 8714d5e commit f67d45e

File tree

3 files changed

+157
-28
lines changed

3 files changed

+157
-28
lines changed

lib/axon/loop.ex

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1868,7 +1868,9 @@ defmodule Axon.Loop do
18681868
Logger.debug("Axon.Loop fired event #{inspect(event)}")
18691869
end
18701870

1871-
if filter.(state) do
1871+
state = update_counts(state, event)
1872+
1873+
if filter.(state, event) do
18721874
case handler.(state) do
18731875
{:continue, %State{} = state} ->
18741876
if debug? do
@@ -1908,6 +1910,10 @@ defmodule Axon.Loop do
19081910
end)
19091911
end
19101912

1913+
defp update_counts(%State{event_counts: event_counts} = state, event) do
1914+
%{state | event_counts: Map.update(event_counts, event, 1, fn x -> x + 1 end)}
1915+
end
1916+
19111917
# Halts an epoch during looping
19121918
defp halt_epoch(handler_fns, batch_fn, final_metrics_map, loop_state, debug?) do
19131919
case fire_event(:epoch_halted, handler_fns, loop_state, debug?) do
@@ -2155,40 +2161,70 @@ defmodule Axon.Loop do
21552161
# Builds a filter function from an atom, keyword list, or function. A
21562162
# valid filter is an atom which matches on of the valid predicates `:always`
21572163
# or `:once`, a keyword which matches one of the valid predicate-value pairs
2158-
# such as `every: N`, or a function which takes loop state and returns `true`
2159-
# or `false`.
2160-
#
2161-
# TODO(seanmor5): In order to handle custom events and predicate filters,
2162-
# we will need to track event firings in the loop state.
2164+
# such as `every: N`, or a function which takes loop state and the current event
2165+
# and returns `true` to run the handler of `false` to avoid it.
21632166
defp build_filter_fn(filter) do
21642167
case filter do
21652168
:always ->
2166-
fn _ -> true end
2169+
fn _, _ -> true end
21672170

21682171
:first ->
2169-
fn
2170-
%State{epoch: 0, iteration: 0} -> true
2171-
_ -> false
2172+
fn %State{event_counts: counts}, event ->
2173+
counts[event] == 1
21722174
end
21732175

2174-
[{:every, n} | _] ->
2175-
fn %State{iteration: iter} ->
2176-
Kernel.rem(iter, n) == 0
2177-
end
2176+
filters when is_list(filters) ->
2177+
Enum.reduce(filters, fn _, _ -> true end, fn
2178+
{:every, n}, acc ->
2179+
fn state, event ->
2180+
acc.(state, event) and filter_every_n(state, event, n)
2181+
end
2182+
2183+
{:before, n}, acc ->
2184+
fn state, event ->
2185+
acc.(state, event) and filter_before_n(state, event, n)
2186+
end
2187+
2188+
{:after, n}, acc ->
2189+
fn state, event ->
2190+
acc.(state, event) and filter_after_n(state, event, n)
2191+
end
2192+
2193+
{:once, n}, acc ->
2194+
fn state, event ->
2195+
acc.(state, event) and filter_once_n(state, event, n)
2196+
end
2197+
end)
21782198

2179-
fun when is_function(fun, 1) ->
2199+
fun when is_function(fun, 2) ->
21802200
fun
21812201

21822202
invalid ->
21832203
raise ArgumentError,
21842204
"Invalid filter #{inspect(invalid)}, a valid filter" <>
21852205
" is an atom which matches a valid filter predicate" <>
21862206
" such as :always or :once, a keyword of predicate-value" <>
2187-
" pairs such as every: N, or an arity-1 function which takes" <>
2188-
" loop state and returns true or false"
2207+
" pairs such as every: N, or an arity-2 function which takes" <>
2208+
" loop state and current event and returns true or false"
21892209
end
21902210
end
21912211

2212+
defp filter_every_n(%State{event_counts: counts}, event, n) do
2213+
rem(counts[event] - 1, n) == 0
2214+
end
2215+
2216+
defp filter_after_n(%State{event_counts: counts}, event, n) do
2217+
counts[event] > n
2218+
end
2219+
2220+
defp filter_before_n(%State{event_counts: counts}, event, n) do
2221+
counts[event] < n
2222+
end
2223+
2224+
defp filter_once_n(%State{event_counts: counts}, event, n) do
2225+
counts[event] == n
2226+
end
2227+
21922228
# JIT-compiles the given function if jit_compile? is true
21932229
# otherwise just applies the function with the given arguments
21942230
defp maybe_jit(fun, args, jit_compile?, jit_opts) do

lib/axon/loop/state.ex

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ defmodule Axon.Loop.State do
3939
`handler_metadata` is a metadata field for storing loop handler metadata.
4040
For example, loop checkpoints with specific metric criteria can store
4141
previous best metrics in the handler meta for use between iterations.
42+
43+
`event_counts` is a metadata field which stores information about the number
44+
of times each event has been fired. This is useful when creating custom filters.
4245
"""
4346
@enforce_keys [:step_state]
4447
defstruct [
@@ -49,6 +52,16 @@ defmodule Axon.Loop.State do
4952
iteration: 0,
5053
max_iteration: -1,
5154
metrics: %{},
52-
times: %{}
55+
times: %{},
56+
event_counts: %{
57+
started: 0,
58+
epoch_started: 0,
59+
iteration_started: 0,
60+
iteration_completed: 0,
61+
epoch_completed: 0,
62+
epoch_halted: 0,
63+
halted: 0,
64+
completed: 0
65+
}
5366
]
5467
end

test/axon/loop_test.exs

Lines changed: 90 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ defmodule Axon.LoopTest do
645645
end)
646646

647647
model
648-
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
648+
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd, log: -1)
649649
|> send_handler(event, filter)
650650
|> Axon.Loop.run(data, %{}, epochs: epochs, iterations: iterations)
651651
end
@@ -663,9 +663,7 @@ defmodule Axon.LoopTest do
663663
end
664664

665665
test "supports an :always filter" do
666-
ExUnit.CaptureIO.capture_io(fn ->
667-
run_dummy_loop!(:iteration_started, :always, 5, 10)
668-
end)
666+
run_dummy_loop!(:iteration_started, :always, 5, 10)
669667

670668
for _ <- 1..50 do
671669
assert_received :iteration_started
@@ -675,15 +673,97 @@ defmodule Axon.LoopTest do
675673
end
676674

677675
test "supports an every: n filter" do
678-
ExUnit.CaptureIO.capture_io(fn ->
679-
run_dummy_loop!(:iteration_started, [every: 2], 5, 10)
680-
end)
676+
run_dummy_loop!(:iteration_started, [every: 2], 5, 10)
681677

682678
for _ <- 1..25 do
683679
assert_received :iteration_started
684680
end
685681

686682
refute_received :iteration_started
683+
684+
run_dummy_loop!(:iteration_completed, [every: 3], 3, 10)
685+
686+
for _ <- 1..10 do
687+
assert_received :iteration_completed
688+
end
689+
690+
refute_received :iteration_completed
691+
end
692+
693+
test "supports after: n filter" do
694+
run_dummy_loop!(:iteration_started, [after: 10], 5, 10)
695+
696+
for _ <- 1..40 do
697+
assert_received :iteration_started
698+
end
699+
700+
refute_received :iteration_started
701+
702+
run_dummy_loop!(:iteration_completed, [after: 10], 5, 10)
703+
704+
for _ <- 1..40 do
705+
assert_received :iteration_completed
706+
end
707+
708+
refute_received :iteration_completed
709+
end
710+
711+
test "supports before: n filter" do
712+
run_dummy_loop!(:iteration_started, [before: 10], 5, 10)
713+
714+
for _ <- 1..9 do
715+
assert_received :iteration_started
716+
end
717+
718+
refute_received :iteration_started
719+
720+
run_dummy_loop!(:iteration_completed, [before: 10], 5, 10)
721+
722+
for _ <- 1..9 do
723+
assert_received :iteration_completed
724+
end
725+
726+
refute_received :iteration_completed
727+
end
728+
729+
test "supports once: n filter" do
730+
run_dummy_loop!(:iteration_started, [once: 30], 5, 10)
731+
732+
assert_received :iteration_started
733+
refute_received :iteration_started
734+
735+
run_dummy_loop!(:iteration_completed, [once: 30], 5, 10)
736+
737+
assert_received :iteration_completed
738+
refute_received :iteration_completed
739+
end
740+
741+
test "supports hybrid filter" do
742+
run_dummy_loop!(:iteration_started, [every: 2, after: 10, before: 40], 5, 10)
743+
744+
for _ <- 1..15 do
745+
assert_received :iteration_started
746+
end
747+
748+
refute_received :iteration_started
749+
end
750+
751+
test "supports :first filter" do
752+
run_dummy_loop!(:iteration_started, :first, 5, 10)
753+
754+
assert_received :iteration_started
755+
refute_received :iteration_started
756+
end
757+
758+
test "supports function filter" do
759+
fun = fn
760+
%{event_counts: counts}, event -> counts[event] == 5
761+
end
762+
763+
run_dummy_loop!(:iteration_started, fun, 5, 10)
764+
765+
assert_received :iteration_started
766+
refute_received :iteration_started
687767
end
688768
end
689769

@@ -814,7 +894,7 @@ defmodule Axon.LoopTest do
814894
assert Map.has_key?(metrics, "validation_accuracy")
815895
{:continue, state}
816896
end,
817-
fn %{epoch: epoch} -> epoch == 1 end
897+
fn %{epoch: epoch}, _ -> epoch == 1 end
818898
)
819899
|> Axon.Loop.run(data, %{}, epochs: 5, iterations: 5)
820900
end)
@@ -846,7 +926,7 @@ defmodule Axon.LoopTest do
846926

847927
{:continue, state}
848928
end,
849-
fn %{epoch: epoch} -> epoch == 1 end
929+
fn %{epoch: epoch}, _ -> epoch == 1 end
850930
)
851931
|> Axon.Loop.run(data, %{}, epochs: 5, iterations: 5)
852932
end)
@@ -934,7 +1014,7 @@ defmodule Axon.LoopTest do
9341014

9351015
{:continue, state}
9361016
end,
937-
fn %{epoch: epoch} -> epoch == 1 end
1017+
fn %{epoch: epoch}, _ -> epoch == 1 end
9381018
)
9391019
|> Axon.Loop.run(data, %{}, epochs: 5, iterations: 5)
9401020
end)

0 commit comments

Comments
 (0)