Skip to content
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b7571ab
Fix `array.values` handling in array cast/embed
mariosasko Oct 5, 2023
feb1c1a
Fix fixed size array with nulls cast
mariosasko Oct 6, 2023
0b2ad10
Bump PyArrow to version 12.0.0
mariosasko Dec 5, 2023
fda0d31
Fix cast/embed
mariosasko Dec 5, 2023
f7b48ba
Resolve merge conflicts
mariosasko Dec 5, 2023
024c029
Remove pdb comment
mariosasko Dec 5, 2023
09b5e15
Add warnings and some comments
mariosasko Dec 19, 2023
19ee42b
CI fix
mariosasko Dec 19, 2023
1505ce6
One more comment
mariosasko Dec 19, 2023
5c8aa27
Merge branch 'main' of github.com:huggingface/datasets into fix-array…
mariosasko Dec 19, 2023
da085d8
Don't install beam
mariosasko Dec 21, 2023
2aec0f7
Fix tests
mariosasko Dec 21, 2023
12c4c57
Still run beam tests?
mariosasko Dec 21, 2023
00e7856
Revert "Still run beam tests?"
mariosasko Dec 21, 2023
9a694d8
Nit
mariosasko Jan 23, 2024
c167caa
Merge branch 'main' of github.com:huggingface/datasets into fix-array…
mariosasko Jan 23, 2024
4828edf
Cleaner implementation
mariosasko Jan 26, 2024
087140e
Cleaner impl part 2
mariosasko Jan 31, 2024
68faedb
Resolve conflict
mariosasko Jan 31, 2024
2881a1a
Nit
mariosasko Jan 31, 2024
86c8ac2
Fix CI
mariosasko Jan 31, 2024
2a211e1
Merge branch 'main' of github.com:huggingface/datasets into fix-array…
mariosasko Jan 31, 2024
79ee0df
Optimization
mariosasko Feb 1, 2024
d088db4
Nit
mariosasko Feb 1, 2024
c9343c0
Nit
mariosasko Feb 4, 2024
3a68113
Update src/datasets/table.py
mariosasko Feb 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ jobs:
run: pip install --upgrade pyarrow huggingface-hub dill
- name: Install dependencies (minimum versions)
if: ${{ matrix.deps_versions != 'deps-latest' }}
run: pip install pyarrow==8.0.0 huggingface-hub==0.19.4 transformers dill==0.3.1.1
run: pip install pyarrow==12.0.0 huggingface-hub==0.19.4 transformers dill==0.3.1.1
- name: Test with pytest
run: |
python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/
Expand Down
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@
# We use numpy>=1.17 to have np.random.Generator (Dataset shuffling)
"numpy>=1.17",
# Backend and serialization.
# Minimum 8.0.0 to be able to use .to_reader()
"pyarrow>=8.0.0",
# Minimum 12.0.0 to be able to concatenate extension arrays
"pyarrow>=12.0.0",
# As long as we allow pyarrow < 14.0.1, to fix vulnerability CVE-2023-47248
"pyarrow-hotfix",
# For smart caching dataset processing
Expand Down Expand Up @@ -166,7 +166,7 @@
"pytest-datadir",
"pytest-xdist",
# optional dependencies
"apache-beam>=2.26.0,<2.44.0;python_version<'3.10'", # doesn't support recent dill versions for recent python versions
"apache-beam>=2.26.0; sys_platform != 'win32' and python_version<'3.10'", # doesn't support recent dill versions for recent python versions and on windows requires pyarrow<12.0.0
"elasticsearch<8.0.0", # 8.0 asks users to provide hosts or cloud_id when instantiating ElasticSearch()
"faiss-cpu>=1.6.4",
"jax>=0.3.14; sys_platform != 'win32'",
Expand Down Expand Up @@ -233,7 +233,7 @@
EXTRAS_REQUIRE = {
"audio": AUDIO_REQUIRE,
"vision": VISION_REQUIRE,
"apache-beam": ["apache-beam>=2.26.0,<2.44.0"],
"apache-beam": ["apache-beam>=2.26.0"],
"tensorflow": [
"tensorflow>=2.2.0,!=2.6.0,!=2.6.1; sys_platform != 'darwin' or platform_machine != 'arm64'",
"tensorflow-macos; sys_platform == 'darwin' and platform_machine == 'arm64'",
Expand Down
36 changes: 16 additions & 20 deletions src/datasets/arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from .filesystems import is_remote_filesystem
from .info import DatasetInfo
from .keyhash import DuplicatedKeysError, KeyHasher
from .table import array_cast, array_concat, cast_array_to_feature, embed_table_storage, table_cast
from .table import array_cast, cast_array_to_feature, embed_table_storage, table_cast
from .utils import logging
from .utils import tqdm as hf_tqdm
from .utils.file_utils import hash_url_to_filename
Expand Down Expand Up @@ -441,7 +441,12 @@ def write_examples_on_file(self):
# This can happen in `.map()` when we want to re-write the same Arrow data
if all(isinstance(row[0][col], (pa.Array, pa.ChunkedArray)) for row in self.current_examples):
arrays = [row[0][col] for row in self.current_examples]
batch_examples[col] = array_concat(arrays)
arrays = [
chunk
for array in arrays
for chunk in (array.chunks if isinstance(array, pa.ChunkedArray) else [array])
]
batch_examples[col] = pa.concat_arrays(arrays)
else:
batch_examples[col] = [
row[0][col].to_pylist()[0] if isinstance(row[0][col], (pa.Array, pa.ChunkedArray)) else row[0][col]
Expand Down Expand Up @@ -669,33 +674,23 @@ def finalize(self, metrics_query_result: dict):
metrics_query_result: `dict` obtained from pipeline_results.metrics().query(m_filter). Make sure
that the filter keeps only the metrics for the considered split, under the namespace `split_name`.
"""
import apache_beam as beam

from .utils import beam_utils

# Beam FileSystems require the system's path separator in the older versions
fs, _, [parquet_path] = fsspec.get_fs_token_paths(self._parquet_path)
parquet_path = str(Path(parquet_path)) if not is_remote_filesystem(fs) else fs.unstrip_protocol(parquet_path)

shards_metadata = list(beam.io.filesystems.FileSystems.match([parquet_path + "*.parquet"])[0].metadata_list)
shards = [metadata.path for metadata in shards_metadata]
num_bytes = sum([metadata.size_in_bytes for metadata in shards_metadata])
shards = fs.glob(parquet_path + "*.parquet")
num_bytes = sum(fs.sizes(shards))
shard_lengths = get_parquet_lengths(shards)

# Convert to arrow
if self._path.endswith(".arrow"):
logger.info(f"Converting parquet files {self._parquet_path} to arrow {self._path}")
shards = [
metadata.path
for metadata in beam.io.filesystems.FileSystems.match([parquet_path + "*.parquet"])[0].metadata_list
]
try: # stream conversion
num_bytes = 0
for shard in hf_tqdm(shards, unit="shards"):
with beam.io.filesystems.FileSystems.open(shard) as source:
with beam.io.filesystems.FileSystems.create(
shard.replace(".parquet", ".arrow")
) as destination:
with fs.open(shard, "rb") as source:
with fs.open(shard.replace(".parquet", ".arrow"), "wb") as destination:
shard_num_bytes, _ = parquet_to_arrow(source, destination)
num_bytes += shard_num_bytes
except OSError as e: # broken pipe can happen if the connection is unstable, do local conversion instead
Expand All @@ -709,12 +704,12 @@ def finalize(self, metrics_query_result: dict):
num_bytes = 0
for shard in hf_tqdm(shards, unit="shards"):
local_parquet_path = os.path.join(local_convert_dir, hash_url_to_filename(shard) + ".parquet")
beam_utils.download_remote_to_local(shard, local_parquet_path)
fs.download(shard, local_parquet_path)
local_arrow_path = local_parquet_path.replace(".parquet", ".arrow")
shard_num_bytes, _ = parquet_to_arrow(local_parquet_path, local_arrow_path)
num_bytes += shard_num_bytes
remote_arrow_path = shard.replace(".parquet", ".arrow")
beam_utils.upload_local_to_remote(local_arrow_path, remote_arrow_path)
fs.upload(local_arrow_path, remote_arrow_path)

# Save metrics
counters_dict = {metric.key.metric.name: metric.result for metric in metrics_query_result["counters"]}
Expand All @@ -735,8 +730,9 @@ def get_parquet_lengths(sources) -> List[int]:
def parquet_to_arrow(source, destination) -> List[int]:
"""Convert parquet file to arrow file. Inputs can be str paths or file-like objects"""
stream = None if isinstance(destination, str) else destination
with ArrowWriter(path=destination, stream=stream) as writer:
parquet_file = pa.parquet.ParquetFile(source)
parquet_file = pa.parquet.ParquetFile(source)
# Beam can create empty Parquet files, so we need to pass the source Parquet file's schema
with ArrowWriter(schema=parquet_file.schema_arrow, path=destination, stream=stream) as writer:
for record_batch in parquet_file.iter_batches():
pa_table = pa.Table.from_batches([record_batch])
writer.write_table(pa_table)
Expand Down
Loading