From 52403bacc2e14759d0fe439e71f92c345eda8436 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Tue, 18 Oct 2022 17:27:48 +0200 Subject: [PATCH] Avoid extra cast in `class_encode_column` --- src/datasets/arrow_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 909f666d4dc..048f0d5abf7 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1559,16 +1559,16 @@ def cast_to_class_labels(batch): ] return batch + new_features = dset.features.copy() + new_features[column] = dst_feat + dset = dset.map( cast_to_class_labels, batched=True, + features=new_features, desc="Casting to class labels", ) - new_features = dset.features.copy() - new_features[column] = dst_feat - dset = dset.cast(new_features) - return dset @fingerprint_transform(inplace=False)