diff --git a/src/datasets/features/audio.py b/src/datasets/features/audio.py index 349d89f3e1b..82c2e7f1f46 100644 --- a/src/datasets/features/audio.py +++ b/src/datasets/features/audio.py @@ -74,7 +74,7 @@ class Audio: def __call__(self): return self.pa_type - def encode_example(self, value: Union[str, bytes, dict]) -> dict: + def encode_example(self, value: Union[str, bytes, bytearray, dict]) -> dict: """Encode example into a format for Arrow. Args: @@ -90,7 +90,7 @@ def encode_example(self, value: Union[str, bytes, dict]) -> dict: raise ImportError("To support encoding audio data, please install 'soundfile'.") from err if isinstance(value, str): return {"bytes": None, "path": value} - elif isinstance(value, bytes): + elif isinstance(value, (bytes, bytearray)): return {"bytes": value, "path": None} elif "array" in value: # convert the audio array to wav bytes diff --git a/src/datasets/features/image.py b/src/datasets/features/image.py index 9682258b271..de18efd2719 100644 --- a/src/datasets/features/image.py +++ b/src/datasets/features/image.py @@ -91,7 +91,7 @@ class Image: def __call__(self): return self.pa_type - def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "PIL.Image.Image"]) -> dict: + def encode_example(self, value: Union[str, bytes, bytearray, dict, np.ndarray, "PIL.Image.Image"]) -> dict: """Encode example into a format for Arrow. Args: @@ -111,7 +111,7 @@ def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "PIL.Image.I if isinstance(value, str): return {"path": value, "bytes": None} - elif isinstance(value, bytes): + elif isinstance(value, (bytes, bytearray)): return {"path": None, "bytes": value} elif isinstance(value, np.ndarray): # convert the image array to PNG/TIFF bytes diff --git a/src/datasets/features/pdf.py b/src/datasets/features/pdf.py index 7e62c50831c..afba910387a 100644 --- a/src/datasets/features/pdf.py +++ b/src/datasets/features/pdf.py @@ -75,7 +75,7 @@ class Pdf: def __call__(self): return self.pa_type - def encode_example(self, value: Union[str, bytes, dict, "pdfplumber.pdf.PDF"]) -> dict: + def encode_example(self, value: Union[str, bytes, bytearray, dict, "pdfplumber.pdf.PDF"]) -> dict: """Encode example into a format for Arrow. Args: @@ -92,7 +92,7 @@ def encode_example(self, value: Union[str, bytes, dict, "pdfplumber.pdf.PDF"]) - if isinstance(value, str): return {"path": value, "bytes": None} - elif isinstance(value, bytes): + elif isinstance(value, (bytes, bytearray)): return {"path": None, "bytes": value} elif pdfplumber is not None and isinstance(value, pdfplumber.pdf.PDF): # convert the pdfplumber.pdf.PDF to bytes diff --git a/src/datasets/features/video.py b/src/datasets/features/video.py index 11eae4812b3..e8c5a0e4456 100644 --- a/src/datasets/features/video.py +++ b/src/datasets/features/video.py @@ -71,7 +71,7 @@ class Video: def __call__(self): return self.pa_type - def encode_example(self, value: Union[str, bytes, Example, np.ndarray, "VideoReader"]) -> Example: + def encode_example(self, value: Union[str, bytes, bytearray, Example, np.ndarray, "VideoReader"]) -> Example: """Encode example into a format for Arrow. Args: @@ -92,7 +92,7 @@ def encode_example(self, value: Union[str, bytes, Example, np.ndarray, "VideoRea if isinstance(value, str): return {"path": value, "bytes": None} - elif isinstance(value, bytes): + elif isinstance(value, (bytes, bytearray)): return {"path": None, "bytes": value} elif isinstance(value, np.ndarray): # convert the video array to bytes diff --git a/src/datasets/keyhash.py b/src/datasets/keyhash.py index 3c75fcfd7ff..5ba2686e259 100644 --- a/src/datasets/keyhash.py +++ b/src/datasets/keyhash.py @@ -35,14 +35,14 @@ from huggingface_hub.utils import insecure_hashlib -def _as_bytes(hash_data: Union[str, int, bytes]) -> bytes: +def _as_bytes(hash_data: Union[str, int, bytes, bytearray]) -> bytes: """ Returns the input hash_data in its bytes form Args: hash_data: the hash salt/key to be converted to bytes """ - if isinstance(hash_data, bytes): + if isinstance(hash_data, (bytes, bytearray)): # Data already in bytes, returns as it as return hash_data elif isinstance(hash_data, str): diff --git a/tests/packaged_modules/test_spark.py b/tests/packaged_modules/test_spark.py index 89406c6a8dd..e7fa3aa5c9f 100644 --- a/tests/packaged_modules/test_spark.py +++ b/tests/packaged_modules/test_spark.py @@ -1,8 +1,10 @@ from unittest.mock import patch +import numpy as np import pyspark import pytest +from datasets import Features, Image, IterableDataset from datasets.builder import InvalidConfigName from datasets.data_files import DataFilesList from datasets.packaged_modules.spark.spark import ( @@ -131,3 +133,38 @@ def test_repartition_df_if_needed_max_num_df_rows(): spark_builder._repartition_df_if_needed(max_shard_size=1) # The new number of partitions should not be greater than the number of rows. assert spark_builder.df.rdd.getNumPartitions() == 100 + + +@require_not_windows +@require_dill_gt_0_3_2 +def test_iterable_image_features(): + spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate() + img_bytes = np.zeros((10, 10, 3), dtype=np.uint8).tobytes() + data = [(img_bytes,)] + df = spark.createDataFrame(data, "image: binary") + features = Features({"image": Image(decode=False)}) + dset = IterableDataset.from_spark(df, features=features) + item = next(iter(dset)) + assert item.keys() == {"image"} + assert item == {"image": {"path": None, "bytes": img_bytes}} + + +@require_not_windows +@require_dill_gt_0_3_2 +def test_iterable_image_features_decode(): + from io import BytesIO + + import PIL.Image + + spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate() + img = PIL.Image.fromarray(np.zeros((10, 10, 3), dtype=np.uint8), "RGB") + buffer = BytesIO() + img.save(buffer, format="PNG") + img_bytes = bytes(buffer.getvalue()) + data = [(img_bytes,)] + df = spark.createDataFrame(data, "image: binary") + features = Features({"image": Image()}) + dset = IterableDataset.from_spark(df, features=features) + item = next(iter(dset)) + assert item.keys() == {"image"} + assert isinstance(item["image"], PIL.Image.Image)