@@ -149,18 +149,19 @@ def reset(self) -> None:
149149 self ._num_examples = 0
150150
151151 @reinit__is_reduced
152- def update (self , output : Tuple [Sequence [Any ], Sequence [Sequence [Any ]]]) -> None :
153- candidate , references = output [0 ], output [1 ]
154- multiref_scores = [self ._compute_score (candidate = candidate , reference = reference ,) for reference in references ]
155- score = self ._mutliref_reducer (multiref_scores )
156- precision = score .precision ()
157- recall = score .recall ()
158- self ._precision += precision
159- self ._recall += recall
160- precision_recall = precision * recall
161- if precision_recall > 0 : # avoid zero division
162- self ._fmeasure += precision_recall / ((1 - self ._alpha ) * precision + self ._alpha * recall )
163- self ._num_examples += 1
152+ def update (self , output : Tuple [Sequence [Sequence [Any ]], Sequence [Sequence [Sequence [Any ]]]]) -> None :
153+ candidates , references = output
154+ for _candidate , _reference in zip (candidates , references ):
155+ multiref_scores = [self ._compute_score (candidate = _candidate , reference = _ref ,) for _ref in _reference ]
156+ score = self ._mutliref_reducer (multiref_scores )
157+ precision = score .precision ()
158+ recall = score .recall ()
159+ self ._precision += precision
160+ self ._recall += recall
161+ precision_recall = precision * recall
162+ if precision_recall > 0 : # avoid zero division
163+ self ._fmeasure += precision_recall / ((1 - self ._alpha ) * precision + self ._alpha * recall )
164+ self ._num_examples += 1
164165
165166 @sync_all_reduce ("_precision" , "_recall" , "_fmeasure" , "_num_examples" )
166167 def compute (self ) -> Mapping :
@@ -192,8 +193,8 @@ class RougeN(_BaseRouge):
192193 __ https://www.aclweb.org/anthology/W04-1013.pdf
193194
194195 - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
195- - `y_pred` must be a sequence of tokens.
196- - `y` must be a list of sequence of tokens.
196+ - `y_pred` (list(list(str))) must be a sequence of tokens.
197+ - `y` (list(list(list(str))) must be a list of sequence of tokens.
197198
198199 Args:
199200 ngram: ngram order (default: 4).
@@ -222,7 +223,7 @@ class RougeN(_BaseRouge):
222223 "there is a cat on the mat".split()
223224 ]
224225
225- m.update((candidate, references))
226+ m.update(([ candidate], [ references] ))
226227
227228 m.compute()
228229 # {'Rouge-2-P': 0.5, 'Rouge-2-R': 0.4, 'Rouge-2-F': 0.4}
@@ -260,8 +261,8 @@ class RougeL(_BaseRouge):
260261 __ https://www.aclweb.org/anthology/W04-1013.pdf
261262
262263 - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
263- - `y_pred` must be a sequence of tokens.
264- - `y` must be a list of sequence of tokens.
264+ - `y_pred` (list(list(str))) must be a sequence of tokens.
265+ - `y` (list(list(list(str))) must be a list of sequence of tokens.
265266
266267 Args:
267268 multiref: reduces scores for multi references. Valid values are "best" and "average" (default: "average").
@@ -288,7 +289,7 @@ class RougeL(_BaseRouge):
288289 "there is a cat on the mat".split()
289290 ]
290291
291- m.update((candidate, references))
292+ m.update(([ candidate], [ references] ))
292293
293294 m.compute()
294295 # {'Rouge-L-P': 0.6, 'Rouge-L-R': 0.5, 'Rouge-L-F': 0.5}
@@ -320,8 +321,8 @@ class Rouge(Metric):
320321 __ https://www.aclweb.org/anthology/W04-1013.pdf
321322
322323 - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
323- - `y_pred` must be a sequence of tokens.
324- - `y` must be a list of sequence of tokens.
324+ - `y_pred` (list(list(str))) must be a sequence of tokens.
325+ - `y` (list(list(list(str))) must be a list of sequence of tokens.
325326
326327 Args:
327328 variants: set of metrics computed. Valid inputs are "L" and integer 1 <= n <= 9.
@@ -349,13 +350,15 @@ class Rouge(Metric):
349350 "there is a cat on the mat".split()
350351 ]
351352
352- m.update((candidate, references))
353+ m.update(([ candidate], [ references] ))
353354
354355 m.compute()
355356 # {'Rouge-L-P': 0.6, 'Rouge-L-R': 0.5, 'Rouge-L-F': 0.5, 'Rouge-2-P': 0.5, 'Rouge-2-R': 0.4,
356357 # 'Rouge-2-F': 0.4}
357358
358359 .. versionadded:: 0.4.5
360+ .. versionchanged:: 0.5.0
361+ Changed input type to work on batch of inputs
359362 """
360363
361364 def __init__ (
@@ -388,7 +391,7 @@ def reset(self) -> None:
388391 m .reset ()
389392
390393 @reinit__is_reduced
391- def update (self , output : Tuple [Sequence [Any ], Sequence [Sequence [Any ]]]) -> None :
394+ def update (self , output : Tuple [Sequence [Sequence [ Any ]] , Sequence [Sequence [Sequence [ Any ] ]]]) -> None :
392395 for m in self .internal_metrics :
393396 m .update (output )
394397
0 commit comments