 from unittest.mock import patch

 import numpy as np
+import pyarrow as pa
+import pyarrow.parquet as pq
 import pytest
 from multiprocess.pool import Pool

 from datasets.arrow_dataset import Dataset
 from datasets.arrow_writer import ArrowWriter
-from datasets.builder import BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
+from datasets.builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
 from datasets.dataset_dict import DatasetDict, IterableDatasetDict
 from datasets.download.download_manager import DownloadMode
 from datasets.features import Features, Value
 from datasets.info import DatasetInfo, PostProcessedInfo
 from datasets.iterable_dataset import IterableDataset
 from datasets.splits import Split, SplitDict, SplitGenerator, SplitInfo
 from datasets.streaming import xjoin
+from datasets.utils.file_utils import is_local_path

-from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_faiss
+from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_beam, require_faiss


 class DummyBuilder(DatasetBuilder):
@@ -57,6 +60,35 @@ def _generate_examples(self):
             yield i, {"text": "foo"}


+class DummyArrowBasedBuilder(ArrowBasedBuilder):
+    def _info(self):
+        return DatasetInfo(features=Features({"text": Value("string")}))
+
+    def _split_generators(self, dl_manager):
+        return [SplitGenerator(name=Split.TRAIN)]
+
+    def _generate_tables(self):
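+        # Yields 10 tables of 10 rows each, so the "train" split ends up with 100 examples.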
+        for i in range(10):
+            yield i, pa.table({"text": ["foo"] * 10})
+
+
+class DummyBeamBasedBuilder(BeamBasedBuilder):
+    def _info(self):
+        return DatasetInfo(features=Features({"text": Value("string")}))
+
+    def _split_generators(self, dl_manager):
+        return [SplitGenerator(name=Split.TRAIN)]
+
+    def _build_pcollection(self, pipeline):
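+        # Creates 10 items and fans each out to 10 examples, so this split also ends up with 100 examples.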
+        import apache_beam as beam
+
+        def _process(item):
+            for i in range(10):
+                yield f"{i}_{item}", {"text": "foo"}
+
+        return pipeline | "Initialize" >> beam.Create(range(10)) | "Extract content" >> beam.FlatMap(_process)
+
+
 class DummyGeneratorBasedBuilderWithIntegers(GeneratorBasedBuilder):
     def _info(self):
         return DatasetInfo(features=Features({"id": Value("int8")}))
@@ -690,6 +722,41 @@ def test_cache_dir_for_data_dir(self):
         self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)


+def test_arrow_based_download_and_prepare(tmp_path):
+    builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
+    builder.download_and_prepare()
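+    # The prepared split is materialized under <cache_dir>/<builder_name>/<config_name>/<version>/.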
+    assert os.path.exists(
+        os.path.join(
+            tmp_path,
+            builder.name,
+            "default",
+            "0.0.0",
+            f"{builder.name}-train.arrow",
+        )
+    )
+    assert builder.info.features == Features({"text": Value("string")})
+    assert builder.info.splits["train"].num_examples == 100
+    assert os.path.exists(os.path.join(tmp_path, builder.name, "default", "0.0.0", "dataset_info.json"))
+
+
+@require_beam
+def test_beam_based_download_and_prepare(tmp_path):
+    builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner")
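+    # DirectRunner executes the Beam pipeline locally in-process, so the test needs no cluster.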
+    builder.download_and_prepare()
+    assert os.path.exists(
+        os.path.join(
+            tmp_path,
+            builder.name,
+            "default",
+            "0.0.0",
+            f"{builder.name}-train.arrow",
+        )
+    )
+    assert builder.info.features == Features({"text": Value("string")})
+    assert builder.info.splits["train"].num_examples == 100
+    assert os.path.exists(os.path.join(tmp_path, builder.name, "default", "0.0.0", "dataset_info.json"))
+
+
 @pytest.mark.parametrize(
     "split, expected_dataset_class, expected_dataset_length",
     [
@@ -846,3 +913,62 @@ def test_builder_config_version(builder_class, kwargs, tmp_path):
     cache_dir = str(tmp_path)
     builder = builder_class(cache_dir=cache_dir, **kwargs)
     assert builder.config.version == "2.0.0"
+
+
+def test_builder_with_filesystem(mockfs):
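+    # "mockfs" is assumed to be a conftest fixture exposing an in-memory fsspec filesystem under the "mock://" protocol.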
+    builder = DummyGeneratorBasedBuilder(cache_dir="mock://", storage_options=mockfs.storage_options)
+    assert builder.cache_dir.startswith("mock://")
+    assert is_local_path(builder._cache_downloaded_dir)
+    assert isinstance(builder._fs, type(mockfs))
+    assert builder._fs.storage_options == mockfs.storage_options
+
+
+def test_builder_with_filesystem_download_and_prepare(mockfs):
+    builder = DummyGeneratorBasedBuilder(cache_dir="mock://", storage_options=mockfs.storage_options)
+    builder.download_and_prepare()
+    assert mockfs.exists(f"{builder.name}/default/0.0.0/dataset_info.json")
+    assert mockfs.exists(f"{builder.name}/default/0.0.0/{builder.name}-train.arrow")
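+    # The build is assumed to happen in a "0.0.0.incomplete" working directory that is renamed away on success.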
+    assert not mockfs.exists(f"{builder.name}/default/0.0.0.incomplete")
+
+
+def test_builder_with_filesystem_download_and_prepare_reload(mockfs, caplog):
+    builder = DummyGeneratorBasedBuilder(cache_dir="mock://", storage_options=mockfs.storage_options)
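+    # Pre-populate the cache with a dataset_info.json and an (empty) arrow file so the builder hits the cached path.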
+    mockfs.makedirs(f"{builder.name}/default/0.0.0")
+    DatasetInfo().write_to_directory(f"{builder.name}/default/0.0.0", fs=mockfs)
+    mockfs.touch(f"{builder.name}/default/0.0.0/{builder.name}-train.arrow")
+    caplog.clear()
+    builder.download_and_prepare()
+    assert "Found cached dataset" in caplog.text
+
+
+def test_generator_based_builder_download_and_prepare_as_parquet(tmp_path):
+    builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path)
+    builder.download_and_prepare(file_format="parquet")
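+    # With file_format="parquet", shards are named <builder_name>-<split>-NNNNN-of-NNNNN.parquet instead of a single .arrow file.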
+    assert builder.info.splits["train"].num_examples == 100
+    parquet_path = os.path.join(
+        tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00001.parquet"
+    )
+    assert os.path.exists(parquet_path)
+    assert pq.ParquetFile(parquet_path) is not None
+
+
+def test_arrow_based_builder_download_and_prepare_as_parquet(tmp_path):
+    builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
+    builder.download_and_prepare(file_format="parquet")
+    assert builder.info.splits["train"].num_examples == 100
+    parquet_path = os.path.join(
+        tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00001.parquet"
+    )
+    assert os.path.exists(parquet_path)
+    assert pq.ParquetFile(parquet_path) is not None
+
+
+@require_beam
+def test_beam_based_builder_download_and_prepare_as_parquet(tmp_path):
+    builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner")
+    builder.download_and_prepare(file_format="parquet")
+    assert builder.info.splits["train"].num_examples == 100
+    parquet_path = os.path.join(
+        tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00001.parquet"
+    )
+    assert os.path.exists(parquet_path)
+    assert pq.ParquetFile(parquet_path) is not None