@@ -645,7 +645,7 @@ defmodule Axon.LoopTest do
645645 end )
646646
647647 model
648- |> Axon.Loop . trainer ( :binary_cross_entropy , :sgd )
648+ |> Axon.Loop . trainer ( :binary_cross_entropy , :sgd , log: - 1 )
649649 |> send_handler ( event , filter )
650650 |> Axon.Loop . run ( data , % { } , epochs: epochs , iterations: iterations )
651651 end
@@ -663,9 +663,7 @@ defmodule Axon.LoopTest do
663663 end
664664
665665 test "supports an :always filter" do
666- ExUnit.CaptureIO . capture_io ( fn ->
667- run_dummy_loop! ( :iteration_started , :always , 5 , 10 )
668- end )
666+ run_dummy_loop! ( :iteration_started , :always , 5 , 10 )
669667
670668 for _ <- 1 .. 50 do
671669 assert_received :iteration_started
@@ -675,15 +673,97 @@ defmodule Axon.LoopTest do
675673 end
676674
677675 test "supports an every: n filter" do
678- ExUnit.CaptureIO . capture_io ( fn ->
679- run_dummy_loop! ( :iteration_started , [ every: 2 ] , 5 , 10 )
680- end )
676+ run_dummy_loop! ( :iteration_started , [ every: 2 ] , 5 , 10 )
681677
682678 for _ <- 1 .. 25 do
683679 assert_received :iteration_started
684680 end
685681
686682 refute_received :iteration_started
683+
684+ run_dummy_loop! ( :iteration_completed , [ every: 3 ] , 3 , 10 )
685+
686+ for _ <- 1 .. 10 do
687+ assert_received :iteration_completed
688+ end
689+
690+ refute_received :iteration_completed
691+ end
692+
693+ test "supports after: n filter" do
694+ run_dummy_loop! ( :iteration_started , [ after: 10 ] , 5 , 10 )
695+
696+ for _ <- 1 .. 40 do
697+ assert_received :iteration_started
698+ end
699+
700+ refute_received :iteration_started
701+
702+ run_dummy_loop! ( :iteration_completed , [ after: 10 ] , 5 , 10 )
703+
704+ for _ <- 1 .. 40 do
705+ assert_received :iteration_completed
706+ end
707+
708+ refute_received :iteration_completed
709+ end
710+
711+ test "supports before: n filter" do
712+ run_dummy_loop! ( :iteration_started , [ before: 10 ] , 5 , 10 )
713+
714+ for _ <- 1 .. 9 do
715+ assert_received :iteration_started
716+ end
717+
718+ refute_received :iteration_started
719+
720+ run_dummy_loop! ( :iteration_completed , [ before: 10 ] , 5 , 10 )
721+
722+ for _ <- 1 .. 9 do
723+ assert_received :iteration_completed
724+ end
725+
726+ refute_received :iteration_completed
727+ end
728+
729+ test "supports once: n filter" do
730+ run_dummy_loop! ( :iteration_started , [ once: 30 ] , 5 , 10 )
731+
732+ assert_received :iteration_started
733+ refute_received :iteration_started
734+
735+ run_dummy_loop! ( :iteration_completed , [ once: 30 ] , 5 , 10 )
736+
737+ assert_received :iteration_completed
738+ refute_received :iteration_completed
739+ end
740+
741+ test "supports hybrid filter" do
742+ run_dummy_loop! ( :iteration_started , [ every: 2 , after: 10 , before: 40 ] , 5 , 10 )
743+
744+ for _ <- 1 .. 15 do
745+ assert_received :iteration_started
746+ end
747+
748+ refute_received :iteration_started
749+ end
750+
751+ test "supports :first filter" do
752+ run_dummy_loop! ( :iteration_started , :first , 5 , 10 )
753+
754+ assert_received :iteration_started
755+ refute_received :iteration_started
756+ end
757+
758+ test "supports function filter" do
759+ fun = fn
760+ % { event_counts: counts } , event -> counts [ event ] == 5
761+ end
762+
763+ run_dummy_loop! ( :iteration_started , fun , 5 , 10 )
764+
765+ assert_received :iteration_started
766+ refute_received :iteration_started
687767 end
688768 end
689769
@@ -814,7 +894,7 @@ defmodule Axon.LoopTest do
814894 assert Map . has_key? ( metrics , "validation_accuracy" )
815895 { :continue , state }
816896 end ,
817- fn % { epoch: epoch } -> epoch == 1 end
897+ fn % { epoch: epoch } , _ -> epoch == 1 end
818898 )
819899 |> Axon.Loop . run ( data , % { } , epochs: 5 , iterations: 5 )
820900 end )
@@ -846,7 +926,7 @@ defmodule Axon.LoopTest do
846926
847927 { :continue , state }
848928 end ,
849- fn % { epoch: epoch } -> epoch == 1 end
929+ fn % { epoch: epoch } , _ -> epoch == 1 end
850930 )
851931 |> Axon.Loop . run ( data , % { } , epochs: 5 , iterations: 5 )
852932 end )
@@ -934,7 +1014,7 @@ defmodule Axon.LoopTest do
9341014
9351015 { :continue , state }
9361016 end ,
937- fn % { epoch: epoch } -> epoch == 1 end
1017+ fn % { epoch: epoch } , _ -> epoch == 1 end
9381018 )
9391019 |> Axon.Loop . run ( data , % { } , epochs: 5 , iterations: 5 )
9401020 end )
0 commit comments