diff --git a/src/datasets/packaged_modules/parquet/parquet.py b/src/datasets/packaged_modules/parquet/parquet.py
index 10797753657..9921a2be6b9 100644
--- a/src/datasets/packaged_modules/parquet/parquet.py
+++ b/src/datasets/packaged_modules/parquet/parquet.py
@@ -15,12 +15,63 @@
 
 @dataclass
 class ParquetConfig(datasets.BuilderConfig):
-    """BuilderConfig for Parquet."""
+    """
+    BuilderConfig for Parquet.
+
+    Args:
+        batch_size (`int`, *optional*):
+            Size of the RecordBatches to iterate on.
+            The default is the row group size (defined by the first row group).
+        columns (`list[str]`, *optional*):
+            List of columns to load; the other columns are ignored.
+            All columns are loaded by default.
+        features (`Features`, *optional*):
+            Cast the data to `features`.
+        filters (`Union[pyarrow.dataset.Expression, list[tuple], list[list[tuple]]]`, *optional*):
+            Return only the rows matching the filter.
+            If possible, the predicate is pushed down to exploit the partition information
+            or internal metadata found in the data source, e.g. Parquet statistics.
+            Otherwise it filters the loaded RecordBatches before yielding them.
+        fragment_scan_options (`pyarrow.dataset.ParquetFragmentScanOptions`, *optional*):
+            Scan-specific options for Parquet fragments.
+            This is especially useful to configure buffering and caching.
+
+    Example:
+
+    Load a subset of columns:
+
+    ```python
+    >>> ds = load_dataset(parquet_dataset_id, columns=["col_0", "col_1"])
+    ```
+
+    Stream and efficiently filter data, possibly skipping entire files or row groups:
+
+    ```python
+    >>> filters = [("col_0", "==", 0)]
+    >>> ds = load_dataset(parquet_dataset_id, streaming=True, filters=filters)
+    ```
+
+    Increase the minimum request size when streaming from the default of 32 MiB to 128 MiB and enable prefetching:
+
+    ```python
+    >>> import pyarrow
+    >>> import pyarrow.dataset
+    >>> fragment_scan_options = pyarrow.dataset.ParquetFragmentScanOptions(
+    ...     cache_options=pyarrow.CacheOptions(
+    ...         prefetch_limit=1,
+    ...         range_size_limit=128 << 20
+    ...     ),
+    ... )
+    >>> ds = load_dataset(parquet_dataset_id, streaming=True, fragment_scan_options=fragment_scan_options)
+    ```
+
+    """
 
     batch_size: Optional[int] = None
     columns: Optional[list[str]] = None
     features: Optional[datasets.Features] = None
     filters: Optional[Union[ds.Expression, list[tuple], list[list[tuple]]]] = None
+    fragment_scan_options: Optional[ds.ParquetFragmentScanOptions] = None
 
     def __post_init__(self):
         super().__post_init__()
@@ -84,9 +135,10 @@ def _generate_tables(self, files):
             if isinstance(self.config.filters, list)
             else self.config.filters
         )
+        parquet_file_format = ds.ParquetFileFormat(default_fragment_scan_options=self.config.fragment_scan_options)
         for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
             with open(file, "rb") as f:
-                parquet_fragment = ds.ParquetFileFormat().make_fragment(f)
+                parquet_fragment = parquet_file_format.make_fragment(f)
                 if parquet_fragment.row_groups:
                     batch_size = self.config.batch_size or parquet_fragment.row_groups[0].num_rows
                 try:
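
For reference, here is a minimal standalone sketch of the code path the second hunk changes: a single `ParquetFileFormat` is constructed once with the configured `default_fragment_scan_options` and reused to make a fragment per file, instead of creating a bare `ParquetFileFormat()` for each one. The file name `example.parquet`, the sample column, and the pyarrow version floor are assumptions for illustration, not part of the patch.

```python
# Sketch of the patched scan path, assuming a local placeholder file
# "example.parquet" and a pyarrow version that exposes CacheOptions
# on ParquetFragmentScanOptions (pyarrow >= 13).
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq

# Write a small placeholder Parquet file to scan.
pq.write_table(pa.table({"col_0": [0, 1, 2]}), "example.parquet")

# One ParquetFileFormat carries the scan options for all fragments,
# mirroring the `default_fragment_scan_options` usage in the diff.
scan_options = ds.ParquetFragmentScanOptions(
    cache_options=pa.CacheOptions(prefetch_limit=1, range_size_limit=128 << 20),
)
parquet_file_format = ds.ParquetFileFormat(default_fragment_scan_options=scan_options)

with open("example.parquet", "rb") as f:
    # Same call as in _generate_tables: build a fragment from an open file.
    fragment = parquet_file_format.make_fragment(f)
    # Iterate the fragment's RecordBatches; the builder additionally
    # derives batch_size from the first row group when not configured.
    for batch in fragment.to_batches():
        print(batch.num_rows)
```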