Skip to content

Commit 22f62f6

Browse files
giraffacarp and lhoestq
authored
fix: Image Feature in Datasets Library Fails to Handle bytearray Objects from Spark DataFrames (#7517) (#7521)
* add bytearray to features encode_example methods * add spark decode test --------- Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 83cc147 commit 22f62f6

File tree

6 files changed

+47
-10
lines changed

6 files changed

+47
-10
lines changed

src/datasets/features/audio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class Audio:
7474
def __call__(self):
7575
return self.pa_type
7676

77-
def encode_example(self, value: Union[str, bytes, dict]) -> dict:
77+
def encode_example(self, value: Union[str, bytes, bytearray, dict]) -> dict:
7878
"""Encode example into a format for Arrow.
7979
8080
Args:
@@ -90,7 +90,7 @@ def encode_example(self, value: Union[str, bytes, dict]) -> dict:
9090
raise ImportError("To support encoding audio data, please install 'soundfile'.") from err
9191
if isinstance(value, str):
9292
return {"bytes": None, "path": value}
93-
elif isinstance(value, bytes):
93+
elif isinstance(value, (bytes, bytearray)):
9494
return {"bytes": value, "path": None}
9595
elif "array" in value:
9696
# convert the audio array to wav bytes

src/datasets/features/image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class Image:
9191
def __call__(self):
9292
return self.pa_type
9393

94-
def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "PIL.Image.Image"]) -> dict:
94+
def encode_example(self, value: Union[str, bytes, bytearray, dict, np.ndarray, "PIL.Image.Image"]) -> dict:
9595
"""Encode example into a format for Arrow.
9696
9797
Args:
@@ -111,7 +111,7 @@ def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "PIL.Image.I
111111

112112
if isinstance(value, str):
113113
return {"path": value, "bytes": None}
114-
elif isinstance(value, bytes):
114+
elif isinstance(value, (bytes, bytearray)):
115115
return {"path": None, "bytes": value}
116116
elif isinstance(value, np.ndarray):
117117
# convert the image array to PNG/TIFF bytes

src/datasets/features/pdf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class Pdf:
7575
def __call__(self):
7676
return self.pa_type
7777

78-
def encode_example(self, value: Union[str, bytes, dict, "pdfplumber.pdf.PDF"]) -> dict:
78+
def encode_example(self, value: Union[str, bytes, bytearray, dict, "pdfplumber.pdf.PDF"]) -> dict:
7979
"""Encode example into a format for Arrow.
8080
8181
Args:
@@ -92,7 +92,7 @@ def encode_example(self, value: Union[str, bytes, dict, "pdfplumber.pdf.PDF"]) -
9292

9393
if isinstance(value, str):
9494
return {"path": value, "bytes": None}
95-
elif isinstance(value, bytes):
95+
elif isinstance(value, (bytes, bytearray)):
9696
return {"path": None, "bytes": value}
9797
elif pdfplumber is not None and isinstance(value, pdfplumber.pdf.PDF):
9898
# convert the pdfplumber.pdf.PDF to bytes

src/datasets/features/video.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class Video:
7171
def __call__(self):
7272
return self.pa_type
7373

74-
def encode_example(self, value: Union[str, bytes, Example, np.ndarray, "VideoReader"]) -> Example:
74+
def encode_example(self, value: Union[str, bytes, bytearray, Example, np.ndarray, "VideoReader"]) -> Example:
7575
"""Encode example into a format for Arrow.
7676
7777
Args:
@@ -92,7 +92,7 @@ def encode_example(self, value: Union[str, bytes, Example, np.ndarray, "VideoRea
9292

9393
if isinstance(value, str):
9494
return {"path": value, "bytes": None}
95-
elif isinstance(value, bytes):
95+
elif isinstance(value, (bytes, bytearray)):
9696
return {"path": None, "bytes": value}
9797
elif isinstance(value, np.ndarray):
9898
# convert the video array to bytes

src/datasets/keyhash.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@
3535
from huggingface_hub.utils import insecure_hashlib
3636

3737

38-
def _as_bytes(hash_data: Union[str, int, bytes]) -> bytes:
38+
def _as_bytes(hash_data: Union[str, int, bytes, bytearray]) -> bytes:
3939
"""
4040
Returns the input hash_data in its bytes form
4141
4242
Args:
4343
hash_data: the hash salt/key to be converted to bytes
4444
"""
45-
if isinstance(hash_data, bytes):
45+
if isinstance(hash_data, (bytes, bytearray)):
4646
# Data already in bytes, returns as it as
4747
return hash_data
4848
elif isinstance(hash_data, str):

tests/packaged_modules/test_spark.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from unittest.mock import patch
22

3+
import numpy as np
34
import pyspark
45
import pytest
56

7+
from datasets import Features, Image, IterableDataset
68
from datasets.builder import InvalidConfigName
79
from datasets.data_files import DataFilesList
810
from datasets.packaged_modules.spark.spark import (
@@ -131,3 +133,38 @@ def test_repartition_df_if_needed_max_num_df_rows():
131133
spark_builder._repartition_df_if_needed(max_shard_size=1)
132134
# The new number of partitions should not be greater than the number of rows.
133135
assert spark_builder.df.rdd.getNumPartitions() == 100
136+
137+
138+
@require_not_windows
139+
@require_dill_gt_0_3_2
140+
def test_iterable_image_features():
141+
spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
142+
img_bytes = np.zeros((10, 10, 3), dtype=np.uint8).tobytes()
143+
data = [(img_bytes,)]
144+
df = spark.createDataFrame(data, "image: binary")
145+
features = Features({"image": Image(decode=False)})
146+
dset = IterableDataset.from_spark(df, features=features)
147+
item = next(iter(dset))
148+
assert item.keys() == {"image"}
149+
assert item == {"image": {"path": None, "bytes": img_bytes}}
150+
151+
152+
@require_not_windows
153+
@require_dill_gt_0_3_2
154+
def test_iterable_image_features_decode():
155+
from io import BytesIO
156+
157+
import PIL.Image
158+
159+
spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
160+
img = PIL.Image.fromarray(np.zeros((10, 10, 3), dtype=np.uint8), "RGB")
161+
buffer = BytesIO()
162+
img.save(buffer, format="PNG")
163+
img_bytes = bytes(buffer.getvalue())
164+
data = [(img_bytes,)]
165+
df = spark.createDataFrame(data, "image: binary")
166+
features = Features({"image": Image()})
167+
dset = IterableDataset.from_spark(df, features=features)
168+
item = next(iter(dset))
169+
assert item.keys() == {"image"}
170+
assert isinstance(item["image"], PIL.Image.Image)

0 commit comments

Comments (0)