
Commit 0be46fd

[BC-breaking] added ROUGE calculation on batch input (#2259)
* added ROUGE calculation on batch input
* updated typehints
1 parent 81e13e1 commit 0be46fd
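The BC-breaking part is the shape of the ``(y_pred, y)`` pair passed to ``update``: each call now takes a batch of candidates and a batch of reference lists rather than a single candidate. A minimal before/after sketch (the sentences and the ``ignite.metrics`` import path are illustrative, not taken from this diff):

from ignite.metrics import Rouge  # import path assumed; the class lives in ignite/metrics/nlp/rouge.py

m = Rouge(variants=["L", 2])

candidate = "the cat is not there".split()
references = [
    "the cat is on the mat".split(),
    "there is a cat on the mat".split(),
]

# before this commit: one sample per call
# m.update((candidate, references))

# after this commit: a batch (here of size 1) per call
m.update(([candidate], [references]))
print(m.compute())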

File tree

2 files changed (+29, -28 lines)


ignite/metrics/nlp/rouge.py

Lines changed: 25 additions & 22 deletions
@@ -149,18 +149,19 @@ def reset(self) -> None:
         self._num_examples = 0
 
     @reinit__is_reduced
-    def update(self, output: Tuple[Sequence[Any], Sequence[Sequence[Any]]]) -> None:
-        candidate, references = output[0], output[1]
-        multiref_scores = [self._compute_score(candidate=candidate, reference=reference,) for reference in references]
-        score = self._mutliref_reducer(multiref_scores)
-        precision = score.precision()
-        recall = score.recall()
-        self._precision += precision
-        self._recall += recall
-        precision_recall = precision * recall
-        if precision_recall > 0:  # avoid zero division
-            self._fmeasure += precision_recall / ((1 - self._alpha) * precision + self._alpha * recall)
-        self._num_examples += 1
+    def update(self, output: Tuple[Sequence[Sequence[Any]], Sequence[Sequence[Sequence[Any]]]]) -> None:
+        candidates, references = output
+        for _candidate, _reference in zip(candidates, references):
+            multiref_scores = [self._compute_score(candidate=_candidate, reference=_ref,) for _ref in _reference]
+            score = self._mutliref_reducer(multiref_scores)
+            precision = score.precision()
+            recall = score.recall()
+            self._precision += precision
+            self._recall += recall
+            precision_recall = precision * recall
+            if precision_recall > 0:  # avoid zero division
+                self._fmeasure += precision_recall / ((1 - self._alpha) * precision + self._alpha * recall)
+            self._num_examples += 1
 
     @sync_all_reduce("_precision", "_recall", "_fmeasure", "_num_examples")
     def compute(self) -> Mapping:
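The `_fmeasure` term accumulated in the loop above is the alpha-weighted harmonic mean P*R / ((1 - alpha)*P + alpha*R), which reduces to the usual F1 at alpha = 0.5. A small standalone sketch of just that arithmetic (the helper name is made up for illustration):

def alpha_fmeasure(precision: float, recall: float, alpha: float = 0.5) -> float:
    # Mirrors the per-example term accumulated into self._fmeasure in update().
    pr = precision * recall
    if pr == 0:  # avoid zero division, as in the diff
        return 0.0
    return pr / ((1 - alpha) * precision + alpha * recall)

# At alpha = 0.5 this is plain F1 = 2PR / (P + R).
assert abs(alpha_fmeasure(0.5, 0.4) - (2 * 0.5 * 0.4) / (0.5 + 0.4)) < 1e-12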
@@ -192,8 +193,8 @@ class RougeN(_BaseRouge):
     __ https://www.aclweb.org/anthology/W04-1013.pdf
 
     - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
-    - `y_pred` must be a sequence of tokens.
-    - `y` must be a list of sequence of tokens.
+    - `y_pred` (list(list(str))) must be a sequence of tokens.
+    - `y` (list(list(list(str))) must be a list of sequence of tokens.
 
     Args:
         ngram: ngram order (default: 4).
@@ -222,7 +223,7 @@ class RougeN(_BaseRouge):
             "there is a cat on the mat".split()
         ]
 
-        m.update((candidate, references))
+        m.update(([candidate], [references]))
 
         m.compute()
         # {'Rouge-2-P': 0.5, 'Rouge-2-R': 0.4, 'Rouge-2-F': 0.4}
@@ -260,8 +261,8 @@ class RougeL(_BaseRouge):
     __ https://www.aclweb.org/anthology/W04-1013.pdf
 
     - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
-    - `y_pred` must be a sequence of tokens.
-    - `y` must be a list of sequence of tokens.
+    - `y_pred` (list(list(str))) must be a sequence of tokens.
+    - `y` (list(list(list(str))) must be a list of sequence of tokens.
 
     Args:
         multiref: reduces scores for multi references. Valid values are "best" and "average" (default: "average").
@@ -288,7 +289,7 @@ class RougeL(_BaseRouge):
             "there is a cat on the mat".split()
         ]
 
-        m.update((candidate, references))
+        m.update(([candidate], [references]))
 
         m.compute()
         # {'Rouge-L-P': 0.6, 'Rouge-L-R': 0.5, 'Rouge-L-F': 0.5}
@@ -320,8 +321,8 @@ class Rouge(Metric):
     __ https://www.aclweb.org/anthology/W04-1013.pdf
 
     - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
-    - `y_pred` must be a sequence of tokens.
-    - `y` must be a list of sequence of tokens.
+    - `y_pred` (list(list(str))) must be a sequence of tokens.
+    - `y` (list(list(list(str))) must be a list of sequence of tokens.
 
     Args:
         variants: set of metrics computed. Valid inputs are "L" and integer 1 <= n <= 9.
@@ -349,13 +350,15 @@ class Rouge(Metric):
             "there is a cat on the mat".split()
         ]
 
-        m.update((candidate, references))
+        m.update(([candidate], [references]))
 
         m.compute()
         # {'Rouge-L-P': 0.6, 'Rouge-L-R': 0.5, 'Rouge-L-F': 0.5, 'Rouge-2-P': 0.5, 'Rouge-2-R': 0.4,
         # 'Rouge-2-F': 0.4}
 
     .. versionadded:: 0.4.5
+    .. versionchanged:: 0.5.0
+        Changed input type to work on batch of inputs
     """
 
     def __init__(
@@ -388,7 +391,7 @@ def reset(self) -> None:
             m.reset()
 
     @reinit__is_reduced
-    def update(self, output: Tuple[Sequence[Any], Sequence[Sequence[Any]]]) -> None:
+    def update(self, output: Tuple[Sequence[Sequence[Any]], Sequence[Sequence[Sequence[Any]]]]) -> None:
         for m in self.internal_metrics:
             m.update(output)
 
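Read together with the updated docstrings, the new signature means `y_pred` is a list of tokenized candidates and `y` is a parallel list of reference lists. A hedged usage sketch for a batch of two samples (sentences invented; constructor arguments follow the docstring above, and the import path is assumed):

from ignite.metrics import Rouge  # import path assumed

m = Rouge(variants=["L", 2], multiref="best", alpha=0.5)

# Two candidates, each paired with its own list of tokenized references.
candidates = [
    "the cat is not there".split(),
    "a dog sits on the mat".split(),
]
references = [
    ["the cat is on the mat".split(), "there is a cat on the mat".split()],
    ["the dog sits on the mat".split()],
]

m.update((candidates, references))  # one call covers the whole batch
print(m.compute())  # mapping with 'Rouge-L-*' and 'Rouge-2-*' entries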

tests/ignite/metrics/nlp/test_rouge.py

Lines changed: 4 additions & 6 deletions
@@ -73,7 +73,7 @@ def test_wrong_inputs():
 def test_rouge_n_alpha(ngram, candidate, reference, expected):
     for alpha in [0, 1, 0.3, 0.5, 0.8]:
         rouge = RougeN(ngram=ngram, alpha=alpha)
-        rouge.update((candidate, [reference]))
+        rouge.update(([candidate], [[reference]]))
         results = rouge.compute()
         assert results[f"Rouge-{ngram}-P"] == expected[0]
         assert results[f"Rouge-{ngram}-R"] == expected[1]
@@ -110,8 +110,7 @@ def test_rouge_metrics(candidates, references):
         lower_split_candidates = [candidate.lower().split() for candidate in candidates]
 
         m = Rouge(variants=[1, 2, 4, "L"], multiref=multiref, alpha=0.5)
-        for candidate, references_per_candidate in zip(lower_split_candidates, lower_split_references):
-            m.update((candidate, references_per_candidate))
+        m.update((lower_split_candidates, lower_split_references))
         results = m.compute()
 
         for key in ["1", "2", "4", "L"]:
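The test simplification above relies on `update` only accumulating per-example sums, so a single batched call should match a loop of single-sample calls. A small equivalence check (data invented for illustration, import path assumed):

from ignite.metrics import Rouge  # import path assumed

cands = ["the cat is not there".split(), "a dog sits on the mat".split()]
refs = [
    ["the cat is on the mat".split()],
    ["the dog sits on the mat".split()],
]

batched = Rouge(variants=[2], alpha=0.5)
batched.update((cands, refs))

looped = Rouge(variants=[2], alpha=0.5)
for c, r in zip(cands, refs):
    looped.update(([c], [r]))

assert batched.compute() == looped.compute()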
@@ -136,7 +135,7 @@ def update(_, i):
         candidate, references = data[i + size * rank]
         lower_split_references = [reference.lower().split() for reference in references[0]]
         lower_split_candidate = candidate[0].lower().split()
-        return lower_split_candidate, lower_split_references
+        return [lower_split_candidate], [lower_split_references]
 
     def _test(metric_device):
         engine = Engine(update)
@@ -158,11 +157,10 @@ def _test(metric_device):
         )
         rouge_1_f, rouge_2_f, rouge_l_f = (0, 0, 0)
         for candidate, references in data:
-            scores = evaluator.get_scores([candidate[0]], [references[0]])
+            scores = evaluator.get_scores(candidate, references)
             rouge_1_f += scores["rouge-1"]["f"]
             rouge_2_f += scores["rouge-2"]["f"]
             rouge_l_f += scores["rouge-l"]["f"]
-
         assert pytest.approx(engine.state.metrics["Rouge-1-F"], abs=1e-4) == rouge_1_f / len(data)
         assert pytest.approx(engine.state.metrics["Rouge-2-F"], abs=1e-4) == rouge_2_f / len(data)
         assert pytest.approx(engine.state.metrics["Rouge-L-F"], abs=1e-4) == rouge_l_f / len(data)
