Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions src/datasets/arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from fsspec.core import url_to_fs

from . import config
from .features import Audio, Features, Image, Value, Video
from .features import Audio, Features, Image, Pdf, Value, Video
from .features.features import (
FeatureType,
_ArrayXDExtensionType,
Expand All @@ -42,7 +42,7 @@
from .keyhash import DuplicatedKeysError, KeyHasher
from .table import array_cast, cast_array_to_feature, embed_table_storage, table_cast
from .utils import logging
from .utils.py_utils import asdict, first_non_null_value
from .utils.py_utils import asdict, first_non_null_non_empty_value


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -189,9 +189,23 @@ def _infer_custom_type_and_encode(data: Iterable) -> tuple[Iterable, Optional[Fe
if config.PIL_AVAILABLE and "PIL" in sys.modules:
import PIL.Image

non_null_idx, non_null_value = first_non_null_value(data)
non_null_idx, non_null_value = first_non_null_non_empty_value(data)
if isinstance(non_null_value, PIL.Image.Image):
return [Image().encode_example(value) if value is not None else None for value in data], Image()
if isinstance(non_null_value, list) and isinstance(non_null_value[0], PIL.Image.Image):
return [[Image().encode_example(x) for x in value] if value is not None else None for value in data], [
Image()
]
if config.PDFPLUMBER_AVAILABLE and "pdfplumber" in sys.modules:
import pdfplumber

non_null_idx, non_null_value = first_non_null_non_empty_value(data)
if isinstance(non_null_value, pdfplumber.pdf.PDF):
return [Pdf().encode_example(value) if value is not None else None for value in data], Pdf()
if isinstance(non_null_value, list) and isinstance(non_null_value[0], pdfplumber.pdf.PDF):
return [[Pdf().encode_example(x) for x in value] if value is not None else None for value in data], [
Pdf()
]
return data, None

def __arrow_array__(self, type: Optional[pa.DataType] = None):
Expand Down Expand Up @@ -222,7 +236,7 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None):
# efficient np array to pyarrow array
if isinstance(data, np.ndarray):
out = numpy_to_pyarrow_listarray(data)
elif isinstance(data, list) and data and isinstance(first_non_null_value(data)[1], np.ndarray):
elif isinstance(data, list) and data and isinstance(first_non_null_non_empty_value(data)[1], np.ndarray):
out = list_of_np_array_to_pyarrow_listarray(data)
else:
trying_cast_to_python_objects = True
Expand Down
7 changes: 6 additions & 1 deletion src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from ..utils.py_utils import asdict, first_non_null_value, zip_dict
from .audio import Audio
from .image import Image, encode_pil_image
from .pdf import Pdf
from .pdf import Pdf, encode_pdfplumber_pdf
from .translation import Translation, TranslationVariableLanguages
from .video import Video

Expand Down Expand Up @@ -299,6 +299,9 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_cas
if config.PIL_AVAILABLE and "PIL" in sys.modules:
import PIL.Image

if config.PDFPLUMBER_AVAILABLE and "pdfplumber" in sys.modules:
import pdfplumber

if isinstance(obj, np.ndarray):
if obj.ndim == 0:
return obj[()], True
Expand Down Expand Up @@ -367,6 +370,8 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_cas
)
elif config.PIL_AVAILABLE and "PIL" in sys.modules and isinstance(obj, PIL.Image.Image):
return encode_pil_image(obj), True
elif config.PDFPLUMBER_AVAILABLE and "pdfplumber" in sys.modules and isinstance(obj, pdfplumber.pdf.PDF):
return encode_pdfplumber_pdf(obj), True
elif isinstance(obj, pd.Series):
return (
_cast_to_python_objects(
Expand Down
43 changes: 22 additions & 21 deletions src/datasets/features/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def encode_example(self, value: Union[str, bytes, dict, "pdfplumber.pdf.PDF"]) -
return {"path": None, "bytes": value}
elif pdfplumber is not None and isinstance(value, pdfplumber.pdf.PDF):
# convert the pdfplumber.pdf.PDF to bytes
return self.encode_pdfplumber_pdf(value)
return encode_pdfplumber_pdf(value)
elif value.get("path") is not None and os.path.isfile(value["path"]):
# we set "bytes": None to not duplicate the data if they're already available locally
return {"bytes": None, "path": value.get("path")}
Expand All @@ -108,26 +108,6 @@ def encode_example(self, value: Union[str, bytes, dict, "pdfplumber.pdf.PDF"]) -
f"A pdf sample should have one of 'path' or 'bytes' but they are missing or None in {value}."
)

def encode_pdfplumber_pdf(pdf: "pdfplumber.pdf.PDF") -> dict:
"""
Encode a pdfplumber.pdf.PDF object into a dictionary.

If the PDF has an associated file path, returns the path. Otherwise, serializes
the PDF content into bytes.

Args:
pdf (pdfplumber.pdf.PDF): A pdfplumber PDF object.

Returns:
dict: A dictionary with "path" or "bytes" field.
"""
if hasattr(pdf, "stream") and hasattr(pdf.stream, "name") and pdf.stream.name:
# Return the path if the PDF has an associated file path
return {"path": pdf.stream.name, "bytes": None}
else:
# Convert the PDF to bytes if no path is available
return {"path": None, "bytes": pdf_to_bytes(pdf)}

def decode_example(self, value: dict, token_per_repo_id=None) -> "pdfplumber.pdf.PDF":
"""Decode example pdf file into pdf data.

Expand Down Expand Up @@ -235,3 +215,24 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr
path_array = pa.array([None] * len(storage), type=pa.string())
storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null())
return array_cast(storage, self.pa_type)


def encode_pdfplumber_pdf(pdf: "pdfplumber.pdf.PDF") -> dict:
"""
Encode a pdfplumber.pdf.PDF object into a dictionary.

If the PDF has an associated file path, returns the path. Otherwise, serializes
the PDF content into bytes.

Args:
pdf (pdfplumber.pdf.PDF): A pdfplumber PDF object.

Returns:
dict: A dictionary with "path" or "bytes" field.
"""
if hasattr(pdf, "stream") and hasattr(pdf.stream, "name") and pdf.stream.name:
# Return the path if the PDF has an associated file path
return {"path": pdf.stream.name, "bytes": None}
else:
# Convert the PDF to bytes if no path is available
return {"path": None, "bytes": pdf_to_bytes(pdf)}
8 changes: 8 additions & 0 deletions src/datasets/utils/py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,14 @@ def first_non_null_value(iterable):
return -1, None


def first_non_null_non_empty_value(iterable):
"""Return the index and the value of the first non-null non-empty value in the iterable. If all values are None or empty, return -1 as index."""
for i, value in enumerate(iterable):
if value is not None and not (isinstance(value, (dict, list)) and len(value) == 0):
return i, value
return -1, None


def zip_dict(*dicts):
"""Iterate over items of dictionaries grouped by their keys."""
for key in unique_values(itertools.chain(*dicts)): # set merge all keys
Expand Down
Loading