Skip to content

Commit 060dc85

Browse files
authored
Return Python float instead of numpy.float64 in sklearn metrics (#2612)
1 parent 3cbc28f commit 060dc85

8 files changed

Lines changed: 18 additions & 16 deletions

File tree

metrics/accuracy/accuracy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,5 +83,7 @@ def _info(self):
8383

8484
def _compute(self, predictions, references, normalize=True, sample_weight=None):
8585
return {
86-
"accuracy": accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight),
86+
"accuracy": accuracy_score(
87+
references, predictions, normalize=normalize, sample_weight=sample_weight
88+
).tolist(),
8789
}

metrics/f1/f1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,5 +106,5 @@ def _compute(self, predictions, references, labels=None, pos_label=1, average="b
106106
pos_label=pos_label,
107107
average=average,
108108
sample_weight=sample_weight,
109-
),
109+
).tolist(),
110110
}

metrics/glue/glue.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,21 @@
8181

8282

8383
def simple_accuracy(preds, labels):
84-
return (preds == labels).mean()
84+
return (preds == labels).mean().tolist()
8585

8686

8787
def acc_and_f1(preds, labels):
8888
acc = simple_accuracy(preds, labels)
89-
f1 = f1_score(y_true=labels, y_pred=preds)
89+
f1 = f1_score(y_true=labels, y_pred=preds).tolist()
9090
return {
9191
"accuracy": acc,
9292
"f1": f1,
9393
}
9494

9595

9696
def pearson_and_spearman(preds, labels):
97-
pearson_corr = pearsonr(preds, labels)[0]
98-
spearman_corr = spearmanr(preds, labels)[0]
97+
pearson_corr = pearsonr(preds, labels)[0].tolist()
98+
spearman_corr = spearmanr(preds, labels)[0].tolist()
9999
return {
100100
"pearson": pearson_corr,
101101
"spearmanr": spearman_corr,

metrics/indic_glue/indic_glue.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,12 @@
7474

7575

7676
def simple_accuracy(preds, labels):
77-
return (preds == labels).mean()
77+
return (preds == labels).mean().tolist()
7878

7979

8080
def acc_and_f1(preds, labels):
8181
acc = simple_accuracy(preds, labels)
82-
f1 = f1_score(y_true=labels, y_pred=preds)
82+
f1 = f1_score(y_true=labels, y_pred=preds).tolist()
8383
return {
8484
"accuracy": acc,
8585
"f1": f1,
@@ -99,7 +99,7 @@ def precision_at_10(en_sentvecs, in_sentvecs):
9999
actual = np.array(range(n))
100100
preds = sim.argsort(axis=1)[:, :10]
101101
matches = np.any(preds == actual[:, None], axis=1)
102-
return matches.mean()
102+
return matches.mean().tolist()
103103

104104

105105
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)

metrics/matthews_correlation/matthews_correlation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,5 +82,5 @@ def _info(self):
8282

8383
def _compute(self, predictions, references, sample_weight=None):
8484
return {
85-
"matthews_correlation": matthews_corrcoef(references, predictions, sample_weight=sample_weight),
85+
"matthews_correlation": matthews_corrcoef(references, predictions, sample_weight=sample_weight).tolist(),
8686
}

metrics/precision/precision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,5 +108,5 @@ def _compute(self, predictions, references, labels=None, pos_label=1, average="b
108108
pos_label=pos_label,
109109
average=average,
110110
sample_weight=sample_weight,
111-
),
111+
).tolist(),
112112
}

metrics/recall/recall.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,5 +108,5 @@ def _compute(self, predictions, references, labels=None, pos_label=1, average="b
108108
pos_label=pos_label,
109109
average=average,
110110
sample_weight=sample_weight,
111-
),
111+
).tolist(),
112112
}

metrics/super_glue/super_glue.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,12 @@
107107

108108

109109
def simple_accuracy(preds, labels):
110-
return (preds == labels).mean()
110+
return (preds == labels).mean().tolist()
111111

112112

113113
def acc_and_f1(preds, labels, f1_avg="binary"):
114114
acc = simple_accuracy(preds, labels)
115-
f1 = f1_score(y_true=labels, y_pred=preds, average=f1_avg)
115+
f1 = f1_score(y_true=labels, y_pred=preds, average=f1_avg).tolist()
116116
return {
117117
"accuracy": acc,
118118
"f1": f1,
@@ -138,9 +138,9 @@ def evaluate_multirc(ids_preds, labels):
138138
f1s.append(f1)
139139
em = int(sum([p == l for p, l in preds_labels]) == len(preds_labels))
140140
ems.append(em)
141-
f1_m = sum(f1s) / len(f1s)
141+
f1_m = (sum(f1s) / len(f1s)).tolist()
142142
em = sum(ems) / len(ems)
143-
f1_a = f1_score(y_true=labels, y_pred=[id_pred["prediction"] for id_pred in ids_preds])
143+
f1_a = f1_score(y_true=labels, y_pred=[id_pred["prediction"] for id_pred in ids_preds]).tolist()
144144
return {"exact_match": em, "f1_m": f1_m, "f1_a": f1_a}
145145

146146

0 commit comments

Comments (0)