Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
run: uv pip install --system --upgrade pyarrow huggingface-hub dill
- name: Install dependencies (minimum versions)
if: ${{ matrix.deps_versions != 'deps-latest' }}
run: uv pip install --system pyarrow==12.0.0 huggingface-hub==0.21.2 transformers dill==0.3.1.1
run: uv pip install --system pyarrow==15.0.0 huggingface-hub==0.21.2 transformers dill==0.3.1.1
- name: Test with pytest
run: |
python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/
Expand Down
5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@
# We use numpy>=1.17 to have np.random.Generator (Dataset shuffling)
"numpy>=1.17",
# Backend and serialization.
# Minimum 12.0.0 to be able to concatenate extension arrays
"pyarrow>=12.0.0",
# Minimum 15.0.0 to be able to cast dictionary types to their underlying types
"pyarrow>=15.0.0",
# NOTE(review): fixes vulnerability CVE-2023-47248 for pyarrow < 14.0.1; with the minimum now at 15.0.0 this hotfix may be removable — confirm before dropping
"pyarrow-hotfix",
# For smart caching dataset processing
Expand Down Expand Up @@ -166,7 +166,6 @@
"pytest-datadir",
"pytest-xdist",
# optional dependencies
"apache-beam>=2.26.0; sys_platform != 'win32' and python_version<'3.10'", # doesn't support recent dill versions for recent python versions and on windows requires pyarrow<12.0.0
"elasticsearch<8.0.0", # 8.0 asks users to provide hosts or cloud_id when instantiating ElasticSearch()
"faiss-cpu>=1.6.4",
"jax>=0.3.14; sys_platform != 'win32'",
Expand Down
5 changes: 5 additions & 0 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,11 @@ def __init__(
f"{e}\nThe 'source' features come from dataset_info.json, and the 'target' ones are those of the dataset arrow file."
)

# In case there are types like pa.dictionary that we need to convert to the underlying type

if self.data.schema != self.info.features.arrow_schema:
self._data = self.data.cast(self.info.features.arrow_schema)

# Infer fingerprint if None

if self._fingerprint is None:
Expand Down
7 changes: 5 additions & 2 deletions src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def _arrow_to_datasets_dtype(arrow_type: pa.DataType) -> str:
return "string"
elif pyarrow.types.is_large_string(arrow_type):
return "large_string"
elif pyarrow.types.is_dictionary(arrow_type):
return _arrow_to_datasets_dtype(arrow_type.value_type)
else:
raise ValueError(f"Arrow type {arrow_type} does not have a datasets dtype equivalent.")

Expand Down Expand Up @@ -1434,8 +1436,6 @@ def generate_from_arrow_type(pa_type: pa.DataType) -> FeatureType:
elif isinstance(pa_type, _ArrayXDExtensionType):
array_feature = [None, None, Array2D, Array3D, Array4D, Array5D][pa_type.ndims]
return array_feature(shape=pa_type.shape, dtype=pa_type.value_type)
elif isinstance(pa_type, pa.DictionaryType):
raise NotImplementedError # TODO(thom) this will need access to the dictionary as well (for labels). I.e. to the py_table
elif isinstance(pa_type, pa.DataType):
return Value(dtype=_arrow_to_datasets_dtype(pa_type))
else:
Expand Down Expand Up @@ -1705,6 +1705,9 @@ def from_arrow_schema(cls, pa_schema: pa.Schema) -> "Features":
It also checks the schema metadata for Hugging Face Datasets features.
Non-nullable fields are not supported and set to nullable.

Also, pa.dictionary is not supported; its underlying value type is used instead.
Therefore datasets converts DictionaryArray objects to arrays of their actual values.

Args:
pa_schema (`pyarrow.Schema`):
Arrow Schema.
Expand Down
6 changes: 6 additions & 0 deletions tests/features/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ def test_string_to_arrow_bijection_for_primitive_types(self):
with self.assertRaises(ValueError):
string_to_arrow(sdt)

def test_categorical_one_way(self):
    """Dictionary (categorical) arrow types map to their value type's datasets dtype.

    This is one-way only: the datasets dtype drops the dictionary encoding,
    so there is no bijection back to the original arrow type.
    """
    dict_type = pa.dictionary(pa.int32(), pa.string())
    self.assertEqual(_arrow_to_datasets_dtype(dict_type), "string")

def test_feature_named_type(self):
"""reference: issue #1110"""
features = Features({"_type": Value("string")})
Expand Down
18 changes: 18 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4826,3 +4826,21 @@ def test_dataset_getitem_raises():
ds[False]
with pytest.raises(TypeError):
ds._getitem(True)


def test_categorical_dataset(tmpdir):
    """Loading a parquet file with a dictionary-encoded column yields plain values."""
    legs = pa.array([2, 4, 5, 100])
    # Build a categorical (dictionary-encoded) string column
    animals_categorical = pa.array(
        ["Flamingo", "Horse", "Brittle stars", "Centipede"]
    ).cast(pa.dictionary(pa.int32(), pa.string()))

    table = pa.Table.from_arrays(
        [legs, animals_categorical], names=["n_legs", "animals"]
    )
    parquet_path = str(tmpdir / "data.parquet")
    pa.parquet.write_table(table, parquet_path)

    first_row = Dataset.from_parquet(parquet_path)[0]

    # Categorical types get transparently converted to string
    assert first_row["animals"] == "Flamingo"