
Commit 11a2c9f

back to pyarrow 1.0.0 + raise error if using old pyarrow for parquet read/write
1 parent 077648b commit 11a2c9f

File tree

9 files changed: +59 additions, -8 deletions


.circleci/config.yml

Lines changed: 4 additions & 4 deletions

@@ -18,7 +18,7 @@ jobs:
      - run: pip install pyarrow --upgrade
      - run: HF_SCRIPTS_VERSION=master python -m pytest -sv ./tests/

-  run_dataset_script_tests_pyarrow_3:
+  run_dataset_script_tests_pyarrow_1:
    working_directory: ~/datasets
    docker:
      - image: circleci/python:3.6
@@ -29,7 +29,7 @@ jobs:
      - run: source venv/bin/activate
      - run: pip install .[tests]
      - run: pip install -r additional-tests-requirements.txt --no-deps
-     - run: pip install pyarrow==3.0.0
+     - run: pip install pyarrow==1.0.0
      - run: HF_SCRIPTS_VERSION=master python -m pytest -sv ./tests/

  run_dataset_script_tests_pyarrow_latest_WIN:
@@ -50,7 +50,7 @@ jobs:
      - run: $env:HF_SCRIPTS_VERSION="master"
      - run: python -m pytest -sv ./tests/

-  run_dataset_script_tests_pyarrow_3_WIN:
+  run_dataset_script_tests_pyarrow_1_WIN:
    working_directory: ~/datasets
    executor:
      name: win/default
@@ -64,7 +64,7 @@ jobs:
      - run: "& venv/Scripts/activate.ps1"
      - run: pip install .[tests]
      - run: pip install -r additional-tests-requirements.txt --no-deps
-     - run: pip install pyarrow==3.0.0
+     - run: pip install pyarrow==1.0.0
      - run: $env:HF_SCRIPTS_VERSION="master"
      - run: python -m pytest -sv ./tests/

setup.py

Lines changed: 2 additions & 2 deletions

@@ -73,9 +73,9 @@
     # We use numpy>=1.17 to have np.random.Generator (Dataset shuffling)
     "numpy>=1.17",
     # Backend and serialization.
-    # Minimum 3.0.0 to support mix of struct and list types in parquet format
+    # Minimum 3.0.0 to support mix of struct and list types in parquet, and batch iterators of parquet data
     # pyarrow 4.0.0 introduced segfault bug, see: https://github.com/huggingface/datasets/pull/2268
-    "pyarrow>=3.0.0,!=4.0.0",
+    "pyarrow>=1.0.0,!=4.0.0",
     # For smart caching dataset processing
     "dill",
     # For performance gains with apache arrow
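The relaxed pin accepts pyarrow 1.x and 2.x again while still excluding the 4.0.0 release with the segfault bug; only the parquet code paths keep a 3.0.0 floor, enforced at runtime in the files below. A quick illustrative check of how the new specifier resolves, using packaging (this snippet is not part of the commit):

    from packaging.specifiers import SpecifierSet

    spec = SpecifierSet(">=1.0.0,!=4.0.0")
    print(spec.contains("1.0.0"))  # True: old pyarrow installs are accepted again
    print(spec.contains("4.0.0"))  # False: the release with the segfault bug stays excluded
    print(spec.contains("3.0.0"))  # True: the minimum that the runtime parquet guards require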

src/datasets/io/parquet.py

Lines changed: 9 additions & 0 deletions

@@ -3,6 +3,7 @@

 import pyarrow as pa
 import pyarrow.parquet as pq
+from packaging import version

 from .. import Dataset, Features, NamedSplit, config
 from ..formatting import query_table
@@ -22,6 +23,10 @@ def __init__(
        keep_in_memory: bool = False,
        **kwargs,
    ):
+       if version.parse(pa.__version__) < version.parse("3.0.0"):
+           raise ImportError(
+               "PyArrow >= 3.0.0 is required to use the ParquetDatasetReader: pip install --upgrade pyarrow"
+           )
        super().__init__(
            path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
        )
@@ -66,6 +71,10 @@ def __init__(
        batch_size: Optional[int] = None,
        **parquet_writer_kwargs,
    ):
+       if version.parse(pa.__version__) < version.parse("3.0.0"):
+           raise ImportError(
+               "PyArrow >= 3.0.0 is required to use the ParquetDatasetWriter: pip install --upgrade pyarrow"
+           )
        self.dataset = dataset
        self.path_or_buf = path_or_buf
        self.batch_size = batch_size
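Both the reader and the writer use the same fail-fast pattern: parse the installed pyarrow version with packaging.version and raise an ImportError with an upgrade hint before any parquet work starts. A minimal standalone sketch of that pattern (the helper name _require_pyarrow_3 is illustrative, not part of the commit):

    import pyarrow as pa
    from packaging import version

    def _require_pyarrow_3(feature_name):
        # Raise early when the installed pyarrow is too old for the parquet features relied on here.
        if version.parse(pa.__version__) < version.parse("3.0.0"):
            raise ImportError(
                f"PyArrow >= 3.0.0 is required to use {feature_name}: pip install --upgrade pyarrow"
            )

    # The commit inlines this check at the top of ParquetDatasetReader.__init__,
    # ParquetDatasetWriter.__init__, and (below) the Parquet builder's _info().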

src/datasets/packaged_modules/parquet/parquet.py

Lines changed: 5 additions & 0 deletions

@@ -5,6 +5,7 @@

 import pyarrow as pa
 import pyarrow.parquet as pq
+from packaging import version

 import datasets

@@ -25,6 +26,10 @@ class Parquet(datasets.ArrowBasedBuilder):
    BUILDER_CONFIG_CLASS = ParquetConfig

    def _info(self):
+       if version.parse(pa.__version__) < version.parse("3.0.0"):
+           raise ImportError(
+               "PyArrow >= 3.0.0 is required to use the Parquet dataset builder: pip install --upgrade pyarrow"
+           )
        return datasets.DatasetInfo(features=self.config.features)

    def _split_generators(self, dl_manager):

tests/io/test_parquet.py

Lines changed: 9 additions & 1 deletion

@@ -4,7 +4,7 @@
 from datasets import Dataset, DatasetDict, Features, NamedSplit, Value
 from datasets.io.parquet import ParquetDatasetReader, ParquetDatasetWriter

-from ..utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases
+from ..utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_pyarrow_at_least_3


 def _check_parquet_dataset(dataset, expected_features):
@@ -16,6 +16,7 @@ def _check_parquet_dataset(dataset, expected_features):
    assert dataset.features[feature].dtype == expected_dtype


+@require_pyarrow_at_least_3
 @pytest.mark.parametrize("keep_in_memory", [False, True])
 def test_dataset_from_parquet_keep_in_memory(keep_in_memory, parquet_path, tmp_path):
    cache_dir = tmp_path / "cache"
@@ -25,6 +26,7 @@ def test_dataset_from_parquet_keep_in_memory(keep_in_memory, parquet_path, tmp_path):
    _check_parquet_dataset(dataset, expected_features)


+@require_pyarrow_at_least_3
 @pytest.mark.parametrize(
    "features",
    [
@@ -46,6 +48,7 @@ def test_dataset_from_parquet_features(features, parquet_path, tmp_path):
    _check_parquet_dataset(dataset, expected_features)


+@require_pyarrow_at_least_3
 @pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
 def test_dataset_from_parquet_split(split, parquet_path, tmp_path):
    cache_dir = tmp_path / "cache"
@@ -55,6 +58,7 @@ def test_dataset_from_parquet_split(split, parquet_path, tmp_path):
    assert dataset.split == str(split) if split else "train"


+@require_pyarrow_at_least_3
 @pytest.mark.parametrize("path_type", [str, list])
 def test_dataset_from_parquet_path_type(path_type, parquet_path, tmp_path):
    if issubclass(path_type, str):
@@ -78,6 +82,7 @@ def _check_parquet_datasetdict(dataset_dict, expected_features, splits=("train",
    assert dataset.features[feature].dtype == expected_dtype


+@require_pyarrow_at_least_3
 @pytest.mark.parametrize("keep_in_memory", [False, True])
 def test_parquet_datasetdict_reader_keep_in_memory(keep_in_memory, parquet_path, tmp_path):
    cache_dir = tmp_path / "cache"
@@ -89,6 +94,7 @@ def test_parquet_datasetdict_reader_keep_in_memory(keep_in_memory, parquet_path, tmp_path):
    _check_parquet_datasetdict(dataset, expected_features)


+@require_pyarrow_at_least_3
 @pytest.mark.parametrize(
    "features",
    [
@@ -110,6 +116,7 @@ def test_parquet_datasetdict_reader_features(features, parquet_path, tmp_path):
    _check_parquet_datasetdict(dataset, expected_features)


+@require_pyarrow_at_least_3
 @pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
 def test_parquet_datasetdict_reader_split(split, parquet_path, tmp_path):
    if split:
@@ -124,6 +131,7 @@ def test_parquet_datasetdict_reader_split(split, parquet_path, tmp_path):
    assert all(dataset[split].split == split for split in path.keys())


+@require_pyarrow_at_least_3
 def test_parquer_write(dataset, tmp_path):
    writer = ParquetDatasetWriter(dataset, tmp_path / "foo.parquet")
    assert writer.write() > 0

tests/test_arrow_dataset.py

Lines changed: 6 additions & 0 deletions

@@ -27,6 +27,7 @@
 from .utils import (
    assert_arrow_memory_doesnt_increase,
    assert_arrow_memory_increases,
+   require_pyarrow_at_least_3,
    require_s3,
    require_tf,
    require_torch,
@@ -1692,6 +1693,7 @@ def test_to_pandas(self, in_memory):
        for col_name in dset.column_names:
            self.assertEqual(len(dset_to_pandas[col_name]), dset.num_rows)

+   @require_pyarrow_at_least_3
    def test_to_parquet(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # File path argument
@@ -2677,6 +2679,7 @@ def _check_parquet_dataset(dataset, expected_features):
    assert dataset.features[feature].dtype == expected_dtype


+@require_pyarrow_at_least_3
 @pytest.mark.parametrize("keep_in_memory", [False, True])
 def test_dataset_from_parquet_keep_in_memory(keep_in_memory, parquet_path, tmp_path):
    cache_dir = tmp_path / "cache"
@@ -2686,6 +2689,7 @@ def test_dataset_from_parquet_keep_in_memory(keep_in_memory, parquet_path, tmp_path):
    _check_parquet_dataset(dataset, expected_features)


+@require_pyarrow_at_least_3
 @pytest.mark.parametrize(
    "features",
    [
@@ -2707,6 +2711,7 @@ def test_dataset_from_parquet_features(features, parquet_path, tmp_path):
    _check_parquet_dataset(dataset, expected_features)


+@require_pyarrow_at_least_3
 @pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
 def test_dataset_from_parquet_split(split, parquet_path, tmp_path):
    cache_dir = tmp_path / "cache"
@@ -2716,6 +2721,7 @@ def test_dataset_from_parquet_split(split, parquet_path, tmp_path):
    assert dataset.split == str(split) if split else "train"


+@require_pyarrow_at_least_3
 @pytest.mark.parametrize("path_type", [str, list])
 def test_dataset_from_parquet_path_type(path_type, parquet_path, tmp_path):
    if issubclass(path_type, str):

tests/test_dataset_common.py

Lines changed: 6 additions & 1 deletion

@@ -22,7 +22,9 @@
 from typing import List, Optional
 from unittest import TestCase

+import pyarrow as pa
 from absl.testing import parameterized
+from packaging import version

 from datasets import cached_path, import_main_class, load_dataset, prepare_module
 from datasets.builder import BuilderConfig, DatasetBuilder
@@ -270,7 +272,10 @@ def test_load_real_dataset_all_configs(self, dataset_name):


 def get_packaged_dataset_names():
-    return [{"testcase_name": x, "dataset_name": x} for x in _PACKAGED_DATASETS_MODULES.keys()]
+    packaged_datasets = [{"testcase_name": x, "dataset_name": x} for x in _PACKAGED_DATASETS_MODULES.keys()]
+    if version.parse(pa.__version__) < version.parse("3.0.0"):  # parquet is not supported for pyarrow<3.0.0
+        packaged_datasets = [pd for pd in packaged_datasets if pd["dataset_name"] != "parquet"]
+    return packaged_datasets


 @parameterized.named_parameters(get_packaged_dataset_names())
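get_packaged_dataset_names() now drops the parquet entry when pyarrow is older than 3.0.0, so the parameterized test class below it never generates a parquet test case on old installs. A sketch of how absl's parameterized consumes such a list of dicts (the test class and body here are illustrative, not from the repository):

    from absl.testing import parameterized

    class PackagedDatasetSmokeTest(parameterized.TestCase):
        @parameterized.named_parameters(
            {"testcase_name": "csv", "dataset_name": "csv"},
            {"testcase_name": "json", "dataset_name": "json"},
            # a "parquet" entry would be filtered out of this list when pyarrow < 3.0.0
        )
        def test_dataset_name_is_string(self, dataset_name):
            self.assertIsInstance(dataset_name, str)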

tests/test_dataset_dict.py

Lines changed: 4 additions & 0 deletions

@@ -16,6 +16,7 @@
 from .utils import (
    assert_arrow_memory_doesnt_increase,
    assert_arrow_memory_increases,
+   require_pyarrow_at_least_3,
    require_s3,
    require_tf,
    require_torch,
@@ -597,6 +598,7 @@ def _check_parquet_datasetdict(dataset_dict, expected_features, splits=("train",
    assert dataset.features[feature].dtype == expected_dtype


+@require_pyarrow_at_least_3
 @pytest.mark.parametrize("keep_in_memory", [False, True])
 def test_datasetdict_from_parquet_keep_in_memory(keep_in_memory, parquet_path, tmp_path):
    cache_dir = tmp_path / "cache"
@@ -606,6 +608,7 @@ def test_datasetdict_from_parquet_keep_in_memory(keep_in_memory, parquet_path, tmp_path):
    _check_parquet_datasetdict(dataset, expected_features)


+@require_pyarrow_at_least_3
 @pytest.mark.parametrize(
    "features",
    [
@@ -627,6 +630,7 @@ def test_datasetdict_from_parquet_features(features, parquet_path, tmp_path):
    _check_parquet_datasetdict(dataset, expected_features)


+@require_pyarrow_at_least_3
 @pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
 def test_datasetdict_from_parquet_split(split, parquet_path, tmp_path):
    if split:

tests/utils.py

Lines changed: 14 additions & 0 deletions

@@ -8,6 +8,7 @@
 from unittest.mock import patch

 import pyarrow as pa
+from packaging import version

 from datasets import config

@@ -34,6 +35,19 @@ def parse_flag_from_env(key, default=False):
 _run_packaged_tests = parse_flag_from_env("RUN_PACKAGED", default=True)


+def require_pyarrow_at_least_3(test_case):
+    """
+    Decorator marking a test that requires PyArrow 3.0.0,
+    to allow nested types in parquet, as well as batch iterators of parquet files.
+
+    These tests are skipped when the PyArrow version is outdated.
+    """
+    if version.parse(config.PYARROW_VERSION) < version.parse("3.0.0"):
+        test_case = unittest.skip("test requires PyArrow >= 3.0.0")(test_case)
+    return test_case
+
+
 def require_beam(test_case):
    """
    Decorator marking a test that requires Apache Beam.
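The decorator wraps the test with unittest.skip, which raises unittest.SkipTest when the test runs; both unittest and pytest report that as a skip rather than a failure, so the pytest-style functions and the TestCase methods above degrade gracefully on old pyarrow. A minimal usage sketch (the test class and assertion are illustrative):

    import unittest

    from .utils import require_pyarrow_at_least_3  # same import style as the test modules above

    class ParquetGuardTest(unittest.TestCase):
        @require_pyarrow_at_least_3
        def test_needs_parquet_support(self):
            # Skipped automatically when the installed pyarrow is older than 3.0.0.
            self.assertTrue(True)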
