+from dataclasses import dataclass
+import numpy as np
 import sklearn.metrics
+from typing import Optional
 from abc import ABC, ABCMeta, abstractmethod

 from ...models.base import BaseModel
 from ...datasets.base import Dataset


+@dataclass
+class MetricResult:
+    metric: "PerformanceMetric"
+    value: float
+    # Number of samples that can actually influence the metric value.
+    affected_samples: int
+    raw_values: Optional[np.ndarray] = None
+
+
 class PerformanceMetric(ABC):
     name: str
     greater_is_better = True

     @abstractmethod
-    def __call__(self, model: BaseModel, dataset: Dataset) -> float:
+    def __call__(self, model: BaseModel, dataset: Dataset) -> MetricResult:
         ...


 class ClassificationPerformanceMetric(PerformanceMetric, metaclass=ABCMeta):
-    def __call__(self, model: BaseModel, dataset: Dataset) -> float:
+    def __call__(self, model: BaseModel, dataset: Dataset) -> MetricResult:
         if not model.is_classification:
             raise ValueError(f"Metric '{self.name}' is only defined for classification models.")

-        y_true = dataset.df[dataset.target]
-        y_pred = model.predict(dataset).prediction
+        y_true = np.asarray(dataset.df[dataset.target])
+        y_pred = np.asarray(model.predict(dataset).prediction)

-        return self._calculate_metric(y_true, y_pred, model)
+        value = self._calculate_metric(y_true, y_pred, model)
+        num_affected = self._calculate_affected_samples(y_true, y_pred, model)
+        return MetricResult(self, value, num_affected)

     @abstractmethod
-    def _calculate_metric(self, y_true, y_pred, model: BaseModel) -> float:
+    def _calculate_metric(self, y_true: np.ndarray, y_pred: np.ndarray, model: BaseModel) -> float:
         ...

+    # By default, every sample can influence the metric value.
+    def _calculate_affected_samples(self, y_true: np.ndarray, y_pred: np.ndarray, model: BaseModel) -> int:
+        return len(y_true)
+

 class Accuracy(ClassificationPerformanceMetric):
     name = "Accuracy"
     greater_is_better = True

-    def _calculate_metric(self, y_true, y_pred, model: BaseModel):
+    def _calculate_metric(self, y_true: np.ndarray, y_pred: np.ndarray, model: BaseModel):
         return sklearn.metrics.accuracy_score(y_true, y_pred)


 class BalancedAccuracy(ClassificationPerformanceMetric):
     name = "Balanced Accuracy"
     greater_is_better = True

-    def _calculate_metric(self, y_true, y_pred, model: BaseModel):
+    def _calculate_metric(self, y_true: np.ndarray, y_pred: np.ndarray, model: BaseModel):
         return sklearn.metrics.balanced_accuracy_score(y_true, y_pred)


 class SklearnClassificationScoreMixin:
     _sklearn_metric: str

-    def _calculate_metric(self, y_true, y_pred, model: BaseModel):
+    def _calculate_metric(self, y_true: np.ndarray, y_pred: np.ndarray, model: BaseModel):
         metric_fn = getattr(sklearn.metrics, self._sklearn_metric)
         if model.is_binary_classification:
             return metric_fn(
@@ -67,57 +83,78 @@ class F1Score(SklearnClassificationScoreMixin, ClassificationPerformanceMetric):
     greater_is_better = True
     _sklearn_metric = "f1_score"

+    def _calculate_affected_samples(self, y_true: np.ndarray, y_pred: np.ndarray, model: BaseModel) -> int:
+        if model.is_binary_classification:
+            # F1 score will not be affected by true negatives
+            neg = model.meta.classification_labels[0]
+            tn = ((y_true == neg) & (y_pred == neg)).sum()
+            return len(y_true) - tn
+
+        return len(y_true)
+

 class Precision(SklearnClassificationScoreMixin, ClassificationPerformanceMetric):
     name = "Precision"
     greater_is_better = True
     _sklearn_metric = "precision_score"

+    def _calculate_affected_samples(self, y_true: np.ndarray, y_pred: np.ndarray, model: BaseModel) -> int:
+        if model.is_binary_classification:
+            # Precision only depends on samples predicted as positive.
+            return (y_pred == model.meta.classification_labels[1]).sum()
+
+        return len(y_true)
+

 class Recall(SklearnClassificationScoreMixin, ClassificationPerformanceMetric):
     name = "Recall"
     greater_is_better = True
     _sklearn_metric = "recall_score"

+    def _calculate_affected_samples(self, y_true: np.ndarray, y_pred: np.ndarray, model: BaseModel) -> int:
+        if model.is_binary_classification:
+            # Recall only depends on samples whose true label is positive.
+            return (y_true == model.meta.classification_labels[1]).sum()
+
+        return len(y_true)
+

 class AUC(PerformanceMetric):
     name = "ROC AUC"
     greater_is_better = True

-    def __call__(self, model: BaseModel, dataset: Dataset) -> float:
+    def __call__(self, model: BaseModel, dataset: Dataset) -> MetricResult:
         y_true = dataset.df[dataset.target]
         if model.is_binary_classification:
             y_score = model.predict(dataset).raw[:, 1]
         else:
             y_score = model.predict(dataset).all_predictions

-        return sklearn.metrics.roc_auc_score(
-            y_true,
-            y_score,
-            multi_class="ovo",
-            labels=model.meta.classification_labels,
+        value = sklearn.metrics.roc_auc_score(
+            y_true, y_score, multi_class="ovo", labels=model.meta.classification_labels
         )

+        return MetricResult(self, value, len(y_true))
+

 class RegressionPerformanceMetric(PerformanceMetric):
-    def __call__(self, model: BaseModel, dataset: Dataset) -> float:
+    def __call__(self, model: BaseModel, dataset: Dataset) -> MetricResult:
         if not model.is_regression:
             raise ValueError(f"Metric '{self.name}' is only defined for regression models.")

         y_true = dataset.df[dataset.target]
         y_pred = model.predict(dataset).prediction

-        return self._calculate_metric(y_true, y_pred, model)
+        value = self._calculate_metric(y_true, y_pred, model)
+        return MetricResult(self, value, len(y_true))

     @abstractmethod
-    def _calculate_metric(self, y_true, y_pred, model: BaseModel) -> float:
+    def _calculate_metric(self, y_true: np.ndarray, y_pred: np.ndarray, model: BaseModel) -> float:
         ...


 class SklearnRegressionScoreMixin:
     _sklearn_metric: str

-    def _calculate_metric(self, y_true, y_pred, model: BaseModel):
+    def _calculate_metric(self, y_true: np.ndarray, y_pred: np.ndarray, model: BaseModel):
         metric_fn = getattr(sklearn.metrics, self._sklearn_metric)
         return metric_fn(y_true, y_pred)
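
For orientation, a minimal usage sketch of the new API (not part of the commit; "model" and "dataset" stand for any objects satisfying the BaseModel and Dataset interfaces above). A metric is still invoked the same way, but now returns a MetricResult instead of a bare float:

    # Hypothetical usage -- "model" and "dataset" are assumed to exist.
    result = Accuracy()(model, dataset)
    print(result.value)             # the accuracy score itself
    print(result.affected_samples)  # samples that can influence the score
    assert result.metric.greater_is_better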
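
Likewise, with the mixin pattern a new sklearn-backed metric only has to name the sklearn.metrics function to use; this MeanAbsoluteError class is an illustration, not part of the diff:

    # Hypothetical subclass -- mean_absolute_error is a real sklearn.metrics
    # function, resolved by the mixin via getattr at call time.
    class MeanAbsoluteError(SklearnRegressionScoreMixin, RegressionPerformanceMetric):
        name = "MAE"
        greater_is_better = False  # lower error is better
        _sklearn_metric = "mean_absolute_error"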