diff --git a/examples/flax/_tests_requirements.txt b/examples/flax/_tests_requirements.txt index f9de455f62bf..f1e0fb2d9071 100644 --- a/examples/flax/_tests_requirements.txt +++ b/examples/flax/_tests_requirements.txt @@ -4,4 +4,5 @@ conllu nltk rouge-score seqeval -tensorboard \ No newline at end of file +tensorboard +evaluate >= 0.2.0 \ No newline at end of file diff --git a/examples/flax/image-captioning/run_image_captioning_flax.py b/examples/flax/image-captioning/run_image_captioning_flax.py index a4deab8041b2..4fe144db8b17 100644 --- a/examples/flax/image-captioning/run_image_captioning_flax.py +++ b/examples/flax/image-captioning/run_image_captioning_flax.py @@ -31,10 +31,11 @@ import datasets import nltk # Here to have a nice missing dependency error message early on import numpy as np -from datasets import Dataset, load_dataset, load_metric +from datasets import Dataset, load_dataset from PIL import Image from tqdm import tqdm +import evaluate import jax import jax.numpy as jnp import optax @@ -811,7 +812,7 @@ def blockwise_data_loader( yield batch # Metric - metric = load_metric("rouge") + metric = evaluate.load("rouge") def postprocess_text(preds, labels): preds = [pred.strip() for pred in preds] diff --git a/examples/flax/question-answering/run_qa.py b/examples/flax/question-answering/run_qa.py index bc3b5acc50b5..0873b19413bf 100644 --- a/examples/flax/question-answering/run_qa.py +++ b/examples/flax/question-answering/run_qa.py @@ -32,9 +32,10 @@ import datasets import numpy as np -from datasets import load_dataset, load_metric +from datasets import load_dataset from tqdm import tqdm +import evaluate import jax import jax.numpy as jnp import optax @@ -776,7 +777,7 @@ def post_processing_function(examples, features, predictions, stage="eval"): references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples] return EvalPrediction(predictions=formatted_predictions, label_ids=references) - metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad") + metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad") def compute_metrics(p: EvalPrediction): return metric.compute(predictions=p.predictions, references=p.label_ids) diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py index bd17141a44b8..d6f8ec78bab9 100644 --- a/examples/flax/summarization/run_summarization_flax.py +++ b/examples/flax/summarization/run_summarization_flax.py @@ -33,9 +33,10 @@ import datasets import nltk # Here to have a nice missing dependency error message early on import numpy as np -from datasets import Dataset, load_dataset, load_metric +from datasets import Dataset, load_dataset from tqdm import tqdm +import evaluate import jax import jax.numpy as jnp import optax @@ -656,7 +657,7 @@ def preprocess_function(examples): ) # Metric - metric = load_metric("rouge") + metric = evaluate.load("rouge") def postprocess_text(preds, labels): preds = [pred.strip() for pred in preds] diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index 84e1c8512565..7f5524dbb437 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -27,9 +27,10 @@ import datasets import numpy as np -from datasets import load_dataset, load_metric +from datasets import load_dataset from tqdm import tqdm +import evaluate import jax import jax.numpy as jnp import optax @@ -570,9 +571,9 @@ def eval_step(state, batch): p_eval_step = jax.pmap(eval_step, axis_name="batch") if data_args.task_name is not None: - metric = load_metric("glue", data_args.task_name) + metric = evaluate.load("glue", data_args.task_name) else: - metric = load_metric("accuracy") + metric = evaluate.load("accuracy") logger.info(f"===== Starting training ({num_epochs} epochs) =====") train_time = 0 diff --git a/examples/flax/token-classification/run_flax_ner.py b/examples/flax/token-classification/run_flax_ner.py index 79e1589e3fbd..0a66b5f1990b 100644 --- a/examples/flax/token-classification/run_flax_ner.py +++ b/examples/flax/token-classification/run_flax_ner.py @@ -29,9 +29,10 @@ import datasets import numpy as np -from datasets import ClassLabel, load_dataset, load_metric +from datasets import ClassLabel, load_dataset from tqdm import tqdm +import evaluate import jax import jax.numpy as jnp import optax @@ -646,7 +647,7 @@ def eval_step(state, batch): p_eval_step = jax.pmap(eval_step, axis_name="batch") - metric = load_metric("seqeval") + metric = evaluate.load("seqeval") def get_labels(y_pred, y_true): # Transform predictions and references tensos to numpy arrays