Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 48 additions & 9 deletions src/datasets/features/nifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ..download.download_config import DownloadConfig
from ..table import array_cast
from ..utils.file_utils import is_local_path, xopen
from ..utils.py_utils import string_to_dict
from ..utils.py_utils import no_op_if_value_is_null, string_to_dict


if TYPE_CHECKING:
Expand Down Expand Up @@ -125,9 +125,6 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "nib.nifti1.Nif
Returns:
`nibabel.Nifti1Image` objects
"""
if not self.decode:
raise NotImplementedError("Decoding is disabled for this feature. Please use Nifti(decode=True) instead.")

if config.NIBABEL_AVAILABLE:
import nibabel as nib
else:
Expand All @@ -141,6 +138,9 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "nib.nifti1.Nif
if path is None:
raise ValueError(f"A nifti should have one of 'path' or 'bytes' but both are None in {value}.")
else:
# gzipped files have the structure: 'gzip://T1.nii::<local_path>'
if path.startswith("gzip://") and is_local_path(path.split("::")[-1]):
path = path.split("::")[-1]
if is_local_path(path):
nifti = nib.load(path)
else:
Expand All @@ -150,11 +150,10 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "nib.nifti1.Nif
if source_url.startswith(config.HF_ENDPOINT)
else config.HUB_DATASETS_HFFS_URL
)
try:
repo_id = string_to_dict(source_url, pattern)["repo_id"]
token = token_per_repo_id.get(repo_id)
except ValueError:
token = None
source_url_fields = string_to_dict(source_url, pattern)
token = (
token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None
)
download_config = DownloadConfig(token=token)
with xopen(path, "rb", download_config=download_config) as f:
nifti = nib.load(f)
Expand All @@ -172,6 +171,46 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "nib.nifti1.Nif

return nifti

def embed_storage(self, storage: pa.StructArray, token_per_repo_id=None) -> pa.StructArray:
    """Embed NifTI files into the Arrow array.

    Args:
        storage (`pa.StructArray`):
            PyArrow array to embed.
        token_per_repo_id (`dict`, *optional*):
            Mapping of Hub `repo_id` to auth token, used when a file has to be
            fetched from a remote URL.

    Returns:
        `pa.StructArray`: Array in the NifTI arrow storage type, that is
        `pa.struct({"bytes": pa.binary(), "path": pa.string()})`.
    """
    if token_per_repo_id is None:
        token_per_repo_id = {}

    @no_op_if_value_is_null
    def _read_file_bytes(file_path):
        # The last "::"-separated segment holds the underlying URL
        # (chained fsspec paths such as 'gzip://...::<url>').
        url = file_path.split("::")[-1]
        url_pattern = (
            config.HUB_DATASETS_URL if url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL
        )
        url_fields = string_to_dict(url, url_pattern)
        # Only Hub-shaped URLs yield a repo_id we can look a token up for.
        repo_token = None if url_fields is None else token_per_repo_id.get(url_fields["repo_id"])
        with xopen(file_path, "rb", download_config=DownloadConfig(token=repo_token)) as f:
            return f.read()

    embedded = []
    for item in storage.to_pylist():
        if item is None:
            embedded.append(None)
        elif item["bytes"] is not None:
            # Bytes already present: keep them as-is.
            embedded.append(item["bytes"])
        else:
            embedded.append(_read_file_bytes(item["path"]))
    bytes_array = pa.array(embedded, type=pa.binary())
    # Keep only the basename: the embedded bytes make the full path redundant.
    path_array = pa.array(
        [None if p is None else os.path.basename(p) for p in storage.field("path").to_pylist()],
        type=pa.string(),
    )
    storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null())
    return array_cast(storage, self.pa_type)

def flatten(self) -> Union["FeatureType", Dict[str, "FeatureType"]]:
"""If in the decodable state, return the feature itself, otherwise flatten the feature into a dictionary."""
from .features import Value
Expand Down
41 changes: 40 additions & 1 deletion tests/features/test_nifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from pathlib import Path

import pyarrow as pa
import pytest

from datasets import Dataset, Features, Nifti
from datasets import Dataset, Features, Nifti, load_dataset
from src.datasets.features.nifti import encode_nibabel_image

from ..utils import require_nibabel
Expand Down Expand Up @@ -89,3 +90,41 @@ def test_encode_nibabel_image(shared_datadir):
assert isinstance(encoded_example_bytes, dict)
assert encoded_example_bytes["bytes"] is not None and encoded_example_bytes["path"] is None
# this cannot be converted back from bytes (yet)


@require_nibabel
def test_embed_storage(shared_datadir):
    """embed_storage should inline the file's bytes so the image round-trips without its path."""
    from io import BytesIO

    import nibabel as nib

    nifti_path = str(shared_datadir / "test_nifti.nii")
    img = nib.load(nifti_path)
    nifti = Nifti()

    # Storage row with no bytes, only a local path: embedding must read the file in.
    bytes_array = pa.array([None], type=pa.binary())
    path_array = pa.array([nifti_path], type=pa.string())
    storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"])

    embedded_storage = nifti.embed_storage(storage)

    embedded_bytes = embedded_storage[0]["bytes"].as_py()
    # Assert before using the value: a None result should fail here with a clear
    # message, not as a TypeError inside BytesIO below.
    assert embedded_bytes is not None

    bio = BytesIO(embedded_bytes)
    fh = nib.FileHolder(fileobj=bio)
    nifti_img = nib.Nifti1Image.from_file_map({"header": fh, "image": fh})

    # The embedded bytes must decode to the same image as the original file.
    assert nifti_img.header == img.header
    assert (nifti_img.affine == img.affine).all()
    assert (nifti_img.get_fdata() == img.get_fdata()).all()


@require_nibabel
def test_load_zipped_file_locally(shared_datadir):
    """Loading a local gzipped .nii.gz through niftifolder should decode to a Nifti1Image."""
    import nibabel as nib

    gz_path = str(shared_datadir / "test_nifti.nii.gz")
    dataset = load_dataset("niftifolder", data_files=gz_path)
    first_example = dataset["train"][0]
    assert isinstance(first_example["nifti"], nib.nifti1.Nifti1Image)
Loading