From e11e77e51beab29e2a44ab27c44146facf4d2662 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Sat, 7 Aug 2021 16:55:49 +0200 Subject: [PATCH 1/2] Support fast tokenizer in BertScore --- metrics/bertscore/bertscore.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/metrics/bertscore/bertscore.py b/metrics/bertscore/bertscore.py index 17c9d908885..953552d5790 100644 --- a/metrics/bertscore/bertscore.py +++ b/metrics/bertscore/bertscore.py @@ -14,9 +14,11 @@ # limitations under the License. """ BERTScore metric. """ +import functools from contextlib import contextmanager import bert_score +from packaging import version import datasets @@ -131,7 +133,20 @@ def _compute( all_layers=False, rescale_with_baseline=False, baseline_path=None, + use_fast_tokenizer=False, ): + get_hash = bert_score.utils.get_hash + scorer = bert_score.BERTScorer + + if version.parse(bert_score.__version__) >= version.parse("0.3.10"): + get_hash = functools.partial(get_hash, use_fast_tokenizer=use_fast_tokenizer) + scorer = functools.partial(scorer, use_fast_tokenizer=use_fast_tokenizer) + elif use_fast_tokenizer: + raise ImportWarning( + "To use a fast tokenizer, the module `bert-score>=0.3.10` is required, and the current version of `bert-score` doesn't match this condition.\n" + 'You can install it with `pip install "bert-score>=0.3.10"`.' + ) + if model_type is None: assert lang is not None, "either lang or model_type should be specified" model_type = bert_score.utils.lang2model[lang.lower()] @@ -139,7 +154,7 @@ def _compute( if num_layers is None: num_layers = bert_score.utils.model2layers[model_type] - hashcode = bert_score.utils.get_hash( + hashcode = get_hash( model=model_type, num_layers=num_layers, idf=idf, @@ -149,7 +164,7 @@ def _compute( with filter_logging_context(): if not hasattr(self, "cached_bertscorer") or self.cached_bertscorer.hash != hashcode: - self.cached_bertscorer = bert_score.BERTScorer( + self.cached_bertscorer = scorer( model_type=model_type, num_layers=num_layers, batch_size=batch_size, From f9544cf1a1670bbe7e4906599caa8f611fcd5ff7 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Mon, 9 Aug 2021 12:17:22 +0200 Subject: [PATCH 2/2] Mention new arg in _KWARGS_DESCRIPTION --- metrics/bertscore/bertscore.py | 1 + 1 file changed, 1 insertion(+) diff --git a/metrics/bertscore/bertscore.py b/metrics/bertscore/bertscore.py index 953552d5790..9247d622185 100644 --- a/metrics/bertscore/bertscore.py +++ b/metrics/bertscore/bertscore.py @@ -79,6 +79,7 @@ def filter_log(record): specified when `rescale_with_baseline` is True. rescale_with_baseline (bool): Rescale bertscore with pre-computed baseline. baseline_path (str): Customized baseline file. + use_fast_tokenizer (bool): `use_fast` parameter passed to HF tokenizer. New in version 0.3.10. Returns: precision: Precision.