Skip to content

Commit b288977

Browse files
kmehant and lhoestq
authored and committed
feat: support non streamable arrow file binary format (#7025)
* feat: support non streamable arrow file binary format Signed-off-by: Mehant Kammakomati <[email protected]> * use generator Co-authored-by: Quentin Lhoest <[email protected]> * feat: add unit test to load data in both arrow formats Signed-off-by: Mehant Kammakomati <[email protected]> --------- Signed-off-by: Mehant Kammakomati <[email protected]> Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 724e767 commit b288977

File tree

2 files changed

+57
-3
lines changed

2 files changed

+57
-3
lines changed

src/datasets/packaged_modules/arrow/arrow.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,11 @@ def _split_generators(self, dl_manager):
4343
if self.info.features is None:
4444
for file in itertools.chain.from_iterable(files):
4545
with open(file, "rb") as f:
46-
self.info.features = datasets.Features.from_arrow_schema(pa.ipc.open_stream(f).schema)
46+
try:
47+
reader = pa.ipc.open_stream(f)
48+
except pa.lib.ArrowInvalid:
49+
reader = pa.ipc.open_file(f)
50+
self.info.features = datasets.Features.from_arrow_schema(reader.schema)
4751
break
4852
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
4953
return splits
@@ -59,7 +63,12 @@ def _generate_tables(self, files):
5963
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
6064
with open(file, "rb") as f:
6165
try:
62-
for batch_idx, record_batch in enumerate(pa.ipc.open_stream(f)):
66+
try:
67+
batches = pa.ipc.open_stream(f)
68+
except pa.lib.ArrowInvalid:
69+
reader = pa.ipc.open_file(f)
70+
batches = (reader.get_batch(i) for i in range(reader.num_record_batches))
71+
for batch_idx, record_batch in enumerate(batches):
6372
pa_table = pa.Table.from_batches([record_batch])
6473
# Uncomment for debugging (will print the Arrow table size and elements)
6574
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")

tests/packaged_modules/test_arrow.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,53 @@
1+
import pyarrow as pa
12
import pytest
23

34
from datasets.builder import InvalidConfigName
45
from datasets.data_files import DataFilesList
5-
from datasets.packaged_modules.arrow.arrow import ArrowConfig
6+
from datasets.packaged_modules.arrow.arrow import Arrow, ArrowConfig
7+
8+
9+
@pytest.fixture
def arrow_file_streaming_format(tmp_path):
    """Write a small table in the Arrow *streaming* IPC format and return its path.

    The table has one list<int32> column named "input_ids" with three rows.
    """
    path = tmp_path / "stream.arrow"
    rows = [[1, 1, 1], [0, 100, 6], [1, 90, 900]]

    schema = pa.schema([pa.field("input_ids", pa.list_(pa.int32()))])
    column = pa.array(rows, type=pa.list_(pa.int32()))
    table = pa.Table.from_arrays([column], schema=schema)
    # Streaming format: written with pa.ipc.new_stream (no footer, forward-only).
    with open(path, "wb") as sink, pa.ipc.new_stream(sink, schema) as writer:
        writer.write_table(table)
    return str(path)
21+
22+
23+
@pytest.fixture
def arrow_file_file_format(tmp_path):
    """Write a small table in the Arrow *file* (random-access) IPC format and return its path.

    Same data as the streaming fixture: one list<int32> column "input_ids", three rows.
    """
    path = tmp_path / "file.arrow"
    rows = [[1, 1, 1], [0, 100, 6], [1, 90, 900]]

    schema = pa.schema([pa.field("input_ids", pa.list_(pa.int32()))])
    column = pa.array(rows, type=pa.list_(pa.int32()))
    table = pa.Table.from_arrays([column], schema=schema)
    # File format: written with pa.ipc.new_file (includes a footer for random access).
    with open(path, "wb") as sink, pa.ipc.new_file(sink, schema) as writer:
        writer.write_table(table)
    return str(path)
35+
36+
37+
@pytest.mark.parametrize(
    "file_fixture, config_kwargs",
    [
        ("arrow_file_streaming_format", {}),
        ("arrow_file_file_format", {}),
    ],
)
def test_arrow_generate_tables(file_fixture, config_kwargs, request):
    """The Arrow builder reads both IPC binary formats (streaming and file)."""
    builder = Arrow(**config_kwargs)
    # _generate_tables takes a nested list of file lists, yielding (key, table) pairs.
    generated = builder._generate_tables([[request.getfixturevalue(file_fixture)]])
    result = pa.concat_tables([tbl for _, tbl in generated])

    assert result.to_pydict() == {"input_ids": [[1, 1, 1], [0, 100, 6], [1, 90, 900]]}
651

752

853
def test_config_raises_when_invalid_name() -> None:

0 commit comments

Comments
 (0)