178 changes: 118 additions & 60 deletions tests/ignite/metrics/nlp/test_bleu.py
@@ -2,6 +2,7 @@
import warnings
from collections import Counter

import numpy as np
import pytest
import torch
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction
@@ -15,6 +16,14 @@
corpus = CorpusForTest(lower_split=True)


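# PyTorch's MPS backend has no float64 support, so float64 tensors/arrays are downcast to float32 on that device.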
def to_float32_if_mps(x, device):
if isinstance(x, torch.Tensor) and device == "mps" and x.dtype == torch.float64:
return x.to(torch.float32)
elif isinstance(x, np.ndarray) and device == "mps" and x.dtype == np.float64:
return x.astype(np.float32)
return x


def test_wrong_inputs():
with pytest.raises(ValueError, match=r"ngram order must be greater than zero"):
Bleu(ngram=0)
@@ -44,101 +53,136 @@ def test_wrong_inputs():
)


def _test(candidates, references, average, smooth="no_smooth", smooth_nltk_fn=None, ngram_range=8):
def _test(candidates, references, average, smooth="no_smooth", smooth_nltk_fn=None, ngram_range=8, device="cpu"):

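# On MPS, downcast any float64 inputs first (see to_float32_if_mps above); token lists pass through unchanged.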
candidates = to_float32_if_mps(candidates, device)
references = to_float32_if_mps(references, device)

for i in range(1, ngram_range):
weights = tuple([1 / i] * i)
bleu = Bleu(ngram=i, average=average, smooth=smooth)
bleu = Bleu(ngram=i, average=average, smooth=smooth, device=device)

if average == "macro":
with warnings.catch_warnings():
warnings.simplefilter("ignore")
reference = sentence_bleu(
references[0], candidates[0], weights=weights, smoothing_function=smooth_nltk_fn
)
assert pytest.approx(reference) == bleu._sentence_bleu(references[0], candidates[0])
computed = bleu._sentence_bleu(references[0], candidates[0])
if isinstance(computed, torch.Tensor):
computed = computed.cpu().float().item()
assert np.allclose(computed, reference, rtol=1e-6)

elif average == "micro":
with warnings.catch_warnings():
warnings.simplefilter("ignore")
reference = corpus_bleu(references, candidates, weights=weights, smoothing_function=smooth_nltk_fn)
assert pytest.approx(reference) == bleu._corpus_bleu(references, candidates)
computed = bleu._corpus_bleu(references, candidates)
if isinstance(computed, torch.Tensor):
computed = computed.cpu().float().item()
assert np.allclose(computed, reference, rtol=1e-6)

bleu.update((candidates, references))
assert pytest.approx(reference) == bleu.compute()
computed = bleu.compute()
if isinstance(computed, torch.Tensor):
computed = computed.cpu().float().item()
assert np.allclose(computed, reference, rtol=1e-6)


@pytest.mark.parametrize(*parametrize_args)
def test_macro_bleu(candidates, references):
_test(candidates, references, "macro")
def test_macro_bleu(candidates, references, available_device):
_test(candidates, references, "macro", device=available_device)


@pytest.mark.parametrize(*parametrize_args)
def test_micro_bleu(candidates, references):
_test(candidates, references, "micro")
def test_micro_bleu(candidates, references, available_device):
_test(candidates, references, "micro", device=available_device)


@pytest.mark.parametrize(*parametrize_args)
def test_macro_bleu_smooth1(candidates, references):
_test(candidates, references, "macro", "smooth1", SmoothingFunction().method1)
def test_macro_bleu_smooth1(candidates, references, available_device):
_test(candidates, references, "macro", "smooth1", SmoothingFunction().method1, device=available_device)


@pytest.mark.parametrize(*parametrize_args)
def test_micro_bleu_smooth1(candidates, references):
_test(candidates, references, "micro", "smooth1", SmoothingFunction().method1)
def test_micro_bleu_smooth1(candidates, references, available_device):
_test(candidates, references, "micro", "smooth1", SmoothingFunction().method1, device=available_device)


@pytest.mark.parametrize(*parametrize_args)
def test_macro_bleu_nltk_smooth2(candidates, references):
_test(candidates, references, "macro", "nltk_smooth2", SmoothingFunction().method2)
def test_macro_bleu_nltk_smooth2(candidates, references, available_device):
_test(candidates, references, "macro", "nltk_smooth2", SmoothingFunction().method2, device=available_device)


@pytest.mark.parametrize(*parametrize_args)
def test_micro_bleu_nltk_smooth2(candidates, references):
_test(candidates, references, "micro", "nltk_smooth2", SmoothingFunction().method2)
def test_micro_bleu_nltk_smooth2(candidates, references, available_device):
_test(candidates, references, "micro", "nltk_smooth2", SmoothingFunction().method2, device=available_device)


@pytest.mark.parametrize(*parametrize_args)
def test_macro_bleu_smooth2(candidates, references):
_test(candidates, references, "macro", "smooth2", SmoothingFunction().method2, 3)
def test_macro_bleu_smooth2(candidates, references, available_device):
_test(candidates, references, "macro", "smooth2", SmoothingFunction().method2, 3, available_device)


@pytest.mark.parametrize(*parametrize_args)
def test_micro_bleu_smooth2(candidates, references):
_test(candidates, references, "micro", "smooth2", SmoothingFunction().method2, 3)


def test_accumulation_macro_bleu():
bleu = Bleu(ngram=4, smooth="smooth2")
bleu.update(([corpus.cand_1], [corpus.references_1]))
bleu.update(([corpus.cand_2a], [corpus.references_2]))
bleu.update(([corpus.cand_2b], [corpus.references_2]))
bleu.update(([corpus.cand_3], [corpus.references_2]))
value = bleu._sentence_bleu(corpus.references_1, corpus.cand_1)
value += bleu._sentence_bleu(corpus.references_2, corpus.cand_2a)
value += bleu._sentence_bleu(corpus.references_2, corpus.cand_2b)
value += bleu._sentence_bleu(corpus.references_2, corpus.cand_3)
assert bleu.compute() == value / 4


def test_accumulation_micro_bleu():
bleu = Bleu(ngram=4, smooth="smooth2", average="micro")
bleu.update(([corpus.cand_1], [corpus.references_1]))
bleu.update(([corpus.cand_2a], [corpus.references_2]))
bleu.update(([corpus.cand_2b], [corpus.references_2]))
bleu.update(([corpus.cand_3], [corpus.references_2]))
def test_micro_bleu_smooth2(candidates, references, available_device):
_test(candidates, references, "micro", "smooth2", SmoothingFunction().method2, 3, device=available_device)


def test_accumulation_macro_bleu(available_device):
bleu = Bleu(ngram=4, smooth="smooth2", device=available_device)
assert bleu._device == torch.device(available_device)
cand_1 = to_float32_if_mps(corpus.cand_1, available_device)
cand_2a = to_float32_if_mps(corpus.cand_2a, available_device)
cand_2b = to_float32_if_mps(corpus.cand_2b, available_device)
cand_3 = to_float32_if_mps(corpus.cand_3, available_device)
ref_1 = to_float32_if_mps(corpus.references_1, available_device)
ref_2 = to_float32_if_mps(corpus.references_2, available_device)

bleu.update(([cand_1], [ref_1]))
bleu.update(([cand_2a], [ref_2]))
bleu.update(([cand_2b], [ref_2]))
bleu.update(([cand_3], [ref_2]))
value = bleu._sentence_bleu(ref_1, cand_1)
value += bleu._sentence_bleu(ref_2, cand_2a)
value += bleu._sentence_bleu(ref_2, cand_2b)
value += bleu._sentence_bleu(ref_2, cand_3)
computed = bleu.compute()
if isinstance(computed, torch.Tensor):
computed = computed.cpu().float().item()
assert np.allclose(computed, value / 4, rtol=1e-6)


def test_accumulation_micro_bleu(available_device):
bleu = Bleu(ngram=4, smooth="smooth2", average="micro", device=available_device)
assert bleu._device == torch.device(available_device)
cand_1 = to_float32_if_mps(corpus.cand_1, available_device)
cand_2a = to_float32_if_mps(corpus.cand_2a, available_device)
cand_2b = to_float32_if_mps(corpus.cand_2b, available_device)
cand_3 = to_float32_if_mps(corpus.cand_3, available_device)
ref_1 = to_float32_if_mps(corpus.references_1, available_device)
ref_2 = to_float32_if_mps(corpus.references_2, available_device)

bleu.update(([cand_1], [ref_1]))
bleu.update(([cand_2a], [ref_2]))
bleu.update(([cand_2b], [ref_2]))
bleu.update(([cand_3], [ref_2]))
value = bleu._corpus_bleu(
[corpus.references_1, corpus.references_2, corpus.references_2, corpus.references_2],
[corpus.cand_1, corpus.cand_2a, corpus.cand_2b, corpus.cand_3],
[ref_1, ref_2, ref_2, ref_2],
[cand_1, cand_2a, cand_2b, cand_3],
)
assert bleu.compute() == value


def test_bleu_batch_macro():
bleu = Bleu(ngram=4)
def test_bleu_batch_macro(available_device):
bleu = Bleu(ngram=4, device=available_device)
assert bleu._device == torch.device(available_device)

# Batch size 3
hypotheses = [corpus.cand_1, corpus.cand_2a, corpus.cand_2b]
refs = [corpus.references_1, corpus.references_2, corpus.references_2]
hypotheses = [to_float32_if_mps(c, available_device) for c in [corpus.cand_1, corpus.cand_2a, corpus.cand_2b]]
refs = [
to_float32_if_mps(r, available_device) for r in [corpus.references_1, corpus.references_2, corpus.references_2]
]
bleu.update((hypotheses, refs))

with warnings.catch_warnings():
@@ -148,7 +192,11 @@ def test_bleu_batch_macro():
+ sentence_bleu(refs[1], hypotheses[1])
+ sentence_bleu(refs[2], hypotheses[2])
) / 3
assert pytest.approx(bleu.compute()) == reference_bleu_score
reference_bleu_score = np.float32(reference_bleu_score)
computed = bleu.compute()
if isinstance(computed, torch.Tensor):
computed = computed.cpu().float().item()
assert np.allclose(computed, reference_bleu_score, rtol=1e-6)

value = 0
for _hypotheses, _refs in zip(hypotheses, refs):
@@ -158,16 +206,19 @@ def test_bleu_batch_micro():
ref_1 = value / len(refs)
ref_2 = bleu.compute()

assert pytest.approx(ref_1) == reference_bleu_score
assert pytest.approx(ref_2) == reference_bleu_score
assert np.allclose(ref_1, reference_bleu_score, rtol=1e-6)
assert np.allclose(ref_2, reference_bleu_score, rtol=1e-6)


def test_bleu_batch_micro():
bleu = Bleu(ngram=4, average="micro")
def test_bleu_batch_micro(available_device):
bleu = Bleu(ngram=4, average="micro", device=available_device)
assert bleu._device == torch.device(available_device)

# Batch size 3
hypotheses = [corpus.cand_1, corpus.cand_2a, corpus.cand_2b]
refs = [corpus.references_1, corpus.references_2, corpus.references_2]
hypotheses = [to_float32_if_mps(c, available_device) for c in [corpus.cand_1, corpus.cand_2a, corpus.cand_2b]]
refs = [
to_float32_if_mps(r, available_device) for r in [corpus.references_1, corpus.references_2, corpus.references_2]
]
bleu.update((hypotheses, refs))

with warnings.catch_warnings():
Expand All @@ -187,9 +238,16 @@ def test_bleu_batch_micro():
(corpus.cand_1, corpus.references_1),
],
)
def test_n_gram_counter(candidates, references):
bleu = Bleu(ngram=4)
def test_n_gram_counter(candidates, references, available_device):
bleu = Bleu(ngram=4, device=available_device)
assert bleu._device == torch.device(available_device)

candidates = to_float32_if_mps(candidates, available_device)
references = to_float32_if_mps(references, available_device)

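# _n_gram_counter may return the lengths as tensors on non-CPU devices; normalize to int before comparing.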
hyp_length, ref_length = bleu._n_gram_counter([references], [candidates], Counter(), Counter())
hyp_length = int(hyp_length)
ref_length = int(ref_length)
assert hyp_length == len(candidates)

ref_lens = (len(reference) for reference in references)
@@ -212,9 +270,9 @@ def _test_macro_distrib_integration(device):
def update(_, i):
return data[i + size * rank]

def _test(metric_device):
def _test(device):
engine = Engine(update)
m = Bleu(ngram=4, smooth="smooth2")
m = Bleu(ngram=4, smooth="smooth2", device=device)
m.attach(engine, "bleu")

engine.run(data=list(range(size)), max_epochs=1)
@@ -256,7 +314,7 @@ def update(_, i):

def _test(metric_device):
engine = Engine(update)
m = Bleu(ngram=4, smooth="smooth2", average="micro")
m = Bleu(ngram=4, smooth="smooth2", average="micro", device=metric_device)
m.attach(engine, "bleu")

engine.run(data=list(range(size)), max_epochs=1)
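These tests rely on an available_device fixture defined elsewhere in the test suite's conftest (not shown in this diff). A minimal sketch of such a fixture, assuming it simply picks the first usable backend:

import pytest
import torch

@pytest.fixture
def available_device():
    # Hypothetical helper: prefer CUDA, then MPS, and fall back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"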
9 changes: 5 additions & 4 deletions tests/ignite/metrics/nlp/test_rouge.py
@@ -84,9 +84,9 @@ def test_wrong_inputs():
(2, "abcdef", "zbdfz", (0, 0)),
],
)
def test_rouge_n_alpha(ngram, candidate, reference, expected):
def test_rouge_n_alpha(ngram, candidate, reference, expected, available_device):
for alpha in [0, 1, 0.3, 0.5, 0.8]:
rouge = RougeN(ngram=ngram, alpha=alpha)
rouge = RougeN(ngram=ngram, alpha=alpha, device=available_device)
rouge.update(([candidate], [[reference]]))
results = rouge.compute()
assert results[f"Rouge-{ngram}-P"] == expected[0]
@@ -101,7 +101,7 @@
@pytest.mark.parametrize(
"candidates, references", [corpus.sample_1, corpus.sample_2, corpus.sample_3, corpus.sample_4, corpus.sample_5]
)
def test_rouge_metrics(candidates, references):
def test_rouge_metrics(candidates, references, available_device):
for multiref in ["average", "best"]:
# PERL 1.5.5 reference
apply_avg = multiref == "average"
@@ -123,7 +123,8 @@ def test_rouge_metrics(candidates, references):

lower_split_candidates = [candidate.lower().split() for candidate in candidates]

m = Rouge(variants=[1, 2, 4, "L"], multiref=multiref, alpha=0.5)
m = Rouge(variants=[1, 2, 4, "L"], multiref=multiref, alpha=0.5, device=available_device)
assert m._device == torch.device(available_device)
m.update((lower_split_candidates, lower_split_references))
results = m.compute()
