@@ -20,83 +20,83 @@ def test_wrong_input_shapes():
2020 m .update ((torch .rand (4 , 1 ), torch .rand (4 )))
2121
2222
23- def test_compute (available_device ):
24- a = torch .randn (4 )
25- b = torch .randn (4 )
26- c = torch .randn (4 )
27- d = torch .randn (4 )
28- ground_truth = torch .randn (4 )
23+ def test_compute ():
24+ a = np . random .randn (4 )
25+ b = np . random .randn (4 )
26+ c = np . random .randn (4 )
27+ d = np . random .randn (4 )
28+ ground_truth = np . random .randn (4 )
2929
30- m = CanberraMetric (device = available_device )
31- assert m ._device == torch .device (available_device )
30+ m = CanberraMetric ()
3231
3332 canberra = DistanceMetric .get_metric ("canberra" )
3433
35- m .update ((a , ground_truth ))
36- np_sum = (torch .abs (ground_truth - a ) / (torch .abs (a ) + torch .abs (ground_truth ))).sum ()
34+ m .update ((torch . from_numpy ( a ), torch . from_numpy ( ground_truth ) ))
35+ np_sum = (np .abs (ground_truth - a ) / (np .abs (a ) + np .abs (ground_truth ))).sum ()
3736 assert m .compute () == pytest .approx (np_sum )
38- assert canberra .pairwise ([a . cpu (). numpy () , ground_truth . cpu (). numpy () ])[0 ][1 ] == pytest .approx (np_sum )
37+ assert canberra .pairwise ([a , ground_truth ])[0 ][1 ] == pytest .approx (np_sum )
3938
40- m .update ((b , ground_truth ))
41- np_sum += ((torch .abs (ground_truth - b )) / (torch .abs (b ) + torch .abs (ground_truth ))).sum ()
39+ m .update ((torch . from_numpy ( b ), torch . from_numpy ( ground_truth ) ))
40+ np_sum += ((np .abs (ground_truth - b )) / (np .abs (b ) + np .abs (ground_truth ))).sum ()
4241 assert m .compute () == pytest .approx (np_sum )
4342 v1 = np .hstack ([a , b ])
4443 v2 = np .hstack ([ground_truth , ground_truth ])
4544 assert canberra .pairwise ([v1 , v2 ])[0 ][1 ] == pytest .approx (np_sum )
4645
47- m .update ((c , ground_truth ))
48- np_sum += ((torch .abs (ground_truth - c )) / (torch .abs (c ) + torch .abs (ground_truth ))).sum ()
46+ m .update ((torch . from_numpy ( c ), torch . from_numpy ( ground_truth ) ))
47+ np_sum += ((np .abs (ground_truth - c )) / (np .abs (c ) + np .abs (ground_truth ))).sum ()
4948 assert m .compute () == pytest .approx (np_sum )
5049 v1 = np .hstack ([v1 , c ])
5150 v2 = np .hstack ([v2 , ground_truth ])
5251 assert canberra .pairwise ([v1 , v2 ])[0 ][1 ] == pytest .approx (np_sum )
5352
54- m .update ((d , ground_truth ))
55- np_sum += (torch .abs (ground_truth - d ) / (torch .abs (d ) + torch .abs (ground_truth ))).sum ()
53+ m .update ((torch . from_numpy ( d ), torch . from_numpy ( ground_truth ) ))
54+ np_sum += (np .abs (ground_truth - d ) / (np .abs (d ) + np .abs (ground_truth ))).sum ()
5655 assert m .compute () == pytest .approx (np_sum )
5756 v1 = np .hstack ([v1 , d ])
5857 v2 = np .hstack ([v2 , ground_truth ])
5958 assert canberra .pairwise ([v1 , v2 ])[0 ][1 ] == pytest .approx (np_sum )
6059
6160
62- @pytest .mark .parametrize ("n_times" , range (3 ))
63- @pytest .mark .parametrize (
64- "test_cases" ,
65- [
66- (torch .rand (size = (100 ,)), torch .rand (size = (100 ,)), 10 ),
67- (torch .rand (size = (100 , 1 )), torch .rand (size = (100 , 1 )), 20 ),
68- ],
69- )
70- def test_integration (n_times , test_cases , available_device ):
71- y_pred , y , batch_size = test_cases
61+ def test_integration ():
62+ def _test (y_pred , y , batch_size ):
63+ def update_fn (engine , batch ):
64+ idx = (engine .state .iteration - 1 ) * batch_size
65+ y_true_batch = np_y [idx : idx + batch_size ]
66+ y_pred_batch = np_y_pred [idx : idx + batch_size ]
67+ return torch .from_numpy (y_pred_batch ), torch .from_numpy (y_true_batch )
7268
73- def update_fn (engine , batch ):
74- idx = (engine .state .iteration - 1 ) * batch_size
75- y_true_batch = y [idx : idx + batch_size ]
76- y_pred_batch = y_pred [idx : idx + batch_size ]
77- return y_pred_batch , y_true_batch
69+ engine = Engine (update_fn )
7870
79- engine = Engine (update_fn )
71+ m = CanberraMetric ()
72+ m .attach (engine , "cm" )
8073
81- m = CanberraMetric ( device = available_device )
82- assert m . _device == torch . device ( available_device )
74+ np_y = y . numpy (). ravel ( )
75+ np_y_pred = y_pred . numpy (). ravel ( )
8376
84- m . attach ( engine , "cm " )
77+ canberra = DistanceMetric . get_metric ( "canberra " )
8578
86- canberra = DistanceMetric .get_metric ("canberra" )
79+ data = list (range (y_pred .shape [0 ] // batch_size ))
80+ cm = engine .run (data , max_epochs = 1 ).metrics ["cm" ]
8781
88- data = list (range (y_pred .shape [0 ] // batch_size ))
89- cm = engine .run (data , max_epochs = 1 ).metrics ["cm" ]
82+ assert canberra .pairwise ([np_y_pred , np_y ])[0 ][1 ] == pytest .approx (cm )
9083
91- pred_np = y_pred .cpu ().numpy ().reshape (len (y_pred ), - 1 )
92- true_np = y .cpu ().numpy ().reshape (len (y ), - 1 )
93- expected = np .sum (canberra .pairwise (pred_np , true_np ).diagonal ())
94- assert expected == pytest .approx (cm )
84+ def get_test_cases ():
85+ test_cases = [
86+ (torch .rand (size = (100 ,)), torch .rand (size = (100 ,)), 10 ),
87+ (torch .rand (size = (100 , 1 )), torch .rand (size = (100 , 1 )), 20 ),
88+ ]
89+ return test_cases
9590
91+ for _ in range (5 ):
92+ # check multiple random inputs as random exact occurrences are rare
93+ test_cases = get_test_cases ()
94+ for y_pred , y , batch_size in test_cases :
95+ _test (y_pred , y , batch_size )
9696
97- def test_error_is_not_nan ( available_device ):
98- m = CanberraMetric ( device = available_device )
99- assert m . _device == torch . device ( available_device )
97+
98+ def test_error_is_not_nan ():
99+ m = CanberraMetric ( )
100100 m .update ((torch .zeros (4 ), torch .zeros (4 )))
101101 assert not (torch .isnan (m ._sum_of_errors ).any () or torch .isinf (m ._sum_of_errors ).any ()), m ._sum_of_errors
102102
0 commit comments