diff --git a/metrics/bertscore/bertscore.py b/metrics/bertscore/bertscore.py index 17c9d908885..9247d622185 100644 --- a/metrics/bertscore/bertscore.py +++ b/metrics/bertscore/bertscore.py @@ -14,9 +14,11 @@ # limitations under the License. """ BERTScore metric. """ +import functools from contextlib import contextmanager import bert_score +from packaging import version import datasets @@ -77,6 +79,7 @@ def filter_log(record): specified when `rescale_with_baseline` is True. rescale_with_baseline (bool): Rescale bertscore with pre-computed baseline. baseline_path (str): Customized baseline file. + use_fast_tokenizer (bool): `use_fast` parameter passed to HF tokenizer. New in version 0.3.10. Returns: precision: Precision. @@ -131,7 +134,20 @@ def _compute( all_layers=False, rescale_with_baseline=False, baseline_path=None, + use_fast_tokenizer=False, ): + get_hash = bert_score.utils.get_hash + scorer = bert_score.BERTScorer + + if version.parse(bert_score.__version__) >= version.parse("0.3.10"): + get_hash = functools.partial(get_hash, use_fast_tokenizer=use_fast_tokenizer) + scorer = functools.partial(scorer, use_fast_tokenizer=use_fast_tokenizer) + elif use_fast_tokenizer: + raise ImportWarning( + "To use a fast tokenizer, the module `bert-score>=0.3.10` is required, and the current version of `bert-score` doesn't match this condition.\n" + 'You can install it with `pip install "bert-score>=0.3.10"`.' + ) + if model_type is None: assert lang is not None, "either lang or model_type should be specified" model_type = bert_score.utils.lang2model[lang.lower()] @@ -139,7 +155,7 @@ def _compute( if num_layers is None: num_layers = bert_score.utils.model2layers[model_type] - hashcode = bert_score.utils.get_hash( + hashcode = get_hash( model=model_type, num_layers=num_layers, idf=idf, @@ -149,7 +165,7 @@ def _compute( with filter_logging_context(): if not hasattr(self, "cached_bertscorer") or self.cached_bertscorer.hash != hashcode: - self.cached_bertscorer = bert_score.BERTScorer( + self.cached_bertscorer = scorer( model_type=model_type, num_layers=num_layers, batch_size=batch_size,