@@ -325,14 +325,14 @@ class Auc(MetricBase):
325325 """
326326
327327 def __init__ (self , name , curve = 'ROC' , num_thresholds = 200 ):
328- super (MetricBase , self ).__init__ (name , curve , num_thresholds )
328+ super (Auc , self ).__init__ (name = name )
329329 self ._curve = curve
330330 self ._num_thresholds = num_thresholds
331331 self ._epsilon = 1e-6
332- self .tp_list = np .ndarray ((num_thresholds , ))
333- self .fn_list = np .ndarray ((num_thresholds , ))
334- self .tn_list = np .ndarray ((num_thresholds , ))
335- self .fp_list = np .ndarray ((num_thresholds , ))
332+ self .tp_list = np .zeros ((num_thresholds , ))
333+ self .fn_list = np .zeros ((num_thresholds , ))
334+ self .tn_list = np .zeros ((num_thresholds , ))
335+ self .fp_list = np .zeros ((num_thresholds , ))
336336
337337 def update (self , labels , predictions , axis = 1 ):
338338 if not _is_numpy_ (labels ):
@@ -350,12 +350,12 @@ def update(self, labels, predictions, axis=1):
350350 tp , fn , tn , fp = 0 , 0 , 0 , 0
351351 for i , lbl in enumerate (labels ):
352352 if lbl :
353- if predictions [i , 0 ] >= thresh :
353+ if predictions [i , 1 ] >= thresh :
354354 tp += 1
355355 else :
356356 fn += 1
357357 else :
358- if predictions [i , 0 ] >= thresh :
358+ if predictions [i , 1 ] >= thresh :
359359 fp += 1
360360 else :
361361 tn += 1
0 commit comments