56 changes: 54 additions & 2 deletions src/datasets/packaged_modules/parquet/parquet.py
@@ -15,12 +15,63 @@

@dataclass
class ParquetConfig(datasets.BuilderConfig):
"""BuilderConfig for Parquet."""
"""
BuilderConfig for Parquet.

Args:
batch_size (`int`, *optional*):
Size of the RecordBatches to iterate on.
The default is the row group size (defined by the first row group).
columns (`list[str]`, *optional*):
List of columns to load; the other columns are ignored.
All columns are loaded by default.
features (`Features`, *optional*):
Cast the data to `features`.
filters (`Union[pyarrow.dataset.Expression, list[tuple], list[list[tuple]]]`, *optional*):
Return only the rows matching the filter.
If possible, the predicate is pushed down to exploit the partition information
or internal metadata found in the data source, e.g. Parquet statistics.
Otherwise, the loaded RecordBatches are filtered before being yielded.
fragment_scan_options (`pyarrow.dataset.ParquetFragmentScanOptions`, *optional*):
Scan-specific options for Parquet fragments.
This is especially useful to configure buffering and caching.

Example:

Load a subset of columns:

```python
>>> from datasets import load_dataset
>>> ds = load_dataset(parquet_dataset_id, columns=["col_0", "col_1"])
```
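
Cast the data to explicit `Features` while loading (a sketch; the column names `"text"` and `"label"` are only illustrative):

```python
>>> from datasets import Features, Value
>>> features = Features({"text": Value("string"), "label": Value("int64")})
>>> ds = load_dataset(parquet_dataset_id, features=features)
```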

Stream the data and filter it efficiently, possibly skipping entire files or row groups:

```python
>>> filters = [("col_0", "==", 0)]
>>> ds = load_dataset(parquet_dataset_id, streaming=True, filters=filters)
```
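
The same filter can also be written as a `pyarrow.dataset.Expression` (a sketch reusing the hypothetical `parquet_dataset_id` and `col_0` from above):

```python
>>> import pyarrow.dataset
>>> filters = pyarrow.dataset.field("col_0") == 0
>>> ds = load_dataset(parquet_dataset_id, streaming=True, filters=filters)
```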

When streaming, increase the minimum request size from 32 MiB (the default) to 128 MiB and enable prefetching:

```python
>>> import pyarrow
>>> import pyarrow.dataset
>>> fragment_scan_options = pyarrow.dataset.ParquetFragmentScanOptions(
... cache_options=pyarrow.CacheOptions(
... prefetch_limit=1,
... range_size_limit=128 << 20
... ),
... )
>>> ds = load_dataset(parquet_dataset_id, streaming=True, fragment_scan_options=fragment_scan_options)
```
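
Control the size of the RecordBatches yielded when iterating over the streamed dataset; `batch_size` is passed through like the other config kwargs above, and the value `10_000` is an arbitrary choice for illustration:

```python
>>> ds = load_dataset(parquet_dataset_id, streaming=True, batch_size=10_000)
```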

"""

batch_size: Optional[int] = None
columns: Optional[list[str]] = None
features: Optional[datasets.Features] = None
filters: Optional[Union[ds.Expression, list[tuple], list[list[tuple]]]] = None
fragment_scan_options: Optional[ds.ParquetFragmentScanOptions] = None

def __post_init__(self):
super().__post_init__()
@@ -84,9 +135,10 @@ def _generate_tables(self, files):
            if isinstance(self.config.filters, list)
            else self.config.filters
        )
+       parquet_file_format = ds.ParquetFileFormat(default_fragment_scan_options=self.config.fragment_scan_options)
        for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
            with open(file, "rb") as f:
-               parquet_fragment = ds.ParquetFileFormat().make_fragment(f)
+               parquet_fragment = parquet_file_format.make_fragment(f)
                if parquet_fragment.row_groups:
                    batch_size = self.config.batch_size or parquet_fragment.row_groups[0].num_rows
                    try:
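
For context, a minimal standalone sketch (not part of this diff) of the pyarrow pattern the new `fragment_scan_options` plumbing relies on; the file name `"data.parquet"`, the `pre_buffer` option, and the `batch_size` value are assumptions for illustration:

```python
import pyarrow.dataset as ds

# Scan options shared by all fragments created from this file format
# (this is what ParquetConfig.fragment_scan_options is forwarded to).
scan_options = ds.ParquetFragmentScanOptions(pre_buffer=True)
file_format = ds.ParquetFileFormat(default_fragment_scan_options=scan_options)

# Build a fragment from an open file object and iterate over its RecordBatches,
# mirroring what _generate_tables does for each input file.
with open("data.parquet", "rb") as f:
    fragment = file_format.make_fragment(f)
    for record_batch in fragment.to_batches(batch_size=1_000):
        print(record_batch.num_rows)
```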