Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/datasets/packaged_modules/csv/csv.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
from dataclasses import dataclass
from typing import List, Optional, Union

Expand Down Expand Up @@ -138,14 +139,14 @@ def _split_generators(self, dl_manager):
files = data_files
if isinstance(files, str):
files = [files]
return [
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": dl_manager.iter_files(files)})
]
files = [dl_manager.iter_files(file) for file in files]
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
splits = []
for split_name, files in data_files.items():
if isinstance(files, str):
files = [files]
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)}))
files = [dl_manager.iter_files(file) for file in files]
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
return splits

def _cast_table(self, pa_table: pa.Table) -> pa.Table:
Expand All @@ -170,7 +171,7 @@ def _generate_tables(self, files):
if schema is not None
else None
)
for file_idx, file in enumerate(files):
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.read_csv_kwargs)
try:
for batch_idx, df in enumerate(csv_file_reader):
Expand Down
11 changes: 6 additions & 5 deletions src/datasets/packaged_modules/json/json.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
import itertools
import json
from dataclasses import dataclass
from typing import Optional
Expand Down Expand Up @@ -50,14 +51,14 @@ def _split_generators(self, dl_manager):
files = data_files
if isinstance(files, str):
files = [files]
return [
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": dl_manager.iter_files(files)})
]
files = [dl_manager.iter_files(file) for file in files]
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
splits = []
for split_name, files in data_files.items():
if isinstance(files, str):
files = [files]
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)}))
files = [dl_manager.iter_files(file) for file in files]
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
return splits

def _cast_table(self, pa_table: pa.Table) -> pa.Table:
Expand All @@ -68,7 +69,7 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
return pa_table

def _generate_tables(self, files):
for file_idx, file in enumerate(files):
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):

# If the file is one json object and if we need to look at the list of items in one specific field
if self.config.field is not None:
Expand Down
11 changes: 6 additions & 5 deletions src/datasets/packaged_modules/pandas/pandas.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
from dataclasses import dataclass
from typing import Optional

Expand Down Expand Up @@ -31,15 +32,15 @@ def _split_generators(self, dl_manager):
if isinstance(files, str):
files = [files]
# Use `dl_manager.iter_files` to skip hidden files in an extracted archive
return [
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": dl_manager.iter_files(files)})
]
files = [dl_manager.iter_files(file) for file in files]
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
splits = []
for split_name, files in data_files.items():
if isinstance(files, str):
files = [files]
# Use `dl_manager.iter_files` to skip hidden files in an extracted archive
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)}))
files = [dl_manager.iter_files(file) for file in files]
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
return splits

def _cast_table(self, pa_table: pa.Table) -> pa.Table:
Expand All @@ -50,7 +51,7 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
return pa_table

def _generate_tables(self, files):
for i, file in enumerate(files):
for i, file in enumerate(itertools.chain.from_iterable(files)):
with open(file, "rb") as f:
pa_table = pa.Table.from_pandas(pd.read_pickle(f))
yield i, self._cast_table(pa_table)
11 changes: 6 additions & 5 deletions src/datasets/packaged_modules/parquet/parquet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
from dataclasses import dataclass
from typing import List, Optional

Expand Down Expand Up @@ -36,15 +37,15 @@ def _split_generators(self, dl_manager):
if isinstance(files, str):
files = [files]
# Use `dl_manager.iter_files` to skip hidden files in an extracted archive
return [
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": dl_manager.iter_files(files)})
]
files = [dl_manager.iter_files(file) for file in files]
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
splits = []
for split_name, files in data_files.items():
if isinstance(files, str):
files = [files]
# Use `dl_manager.iter_files` to skip hidden files in an extracted archive
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)}))
files = [dl_manager.iter_files(file) for file in files]
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
return splits

def _cast_table(self, pa_table: pa.Table) -> pa.Table:
Expand All @@ -61,7 +62,7 @@ def _generate_tables(self, files):
raise ValueError(
f"Tried to load parquet data with columns '{self.config.columns}' with mismatching features '{self.config.features}'"
)
for file_idx, file in enumerate(files):
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
with open(file, "rb") as f:
parquet_file = pq.ParquetFile(f)
try:
Expand Down
11 changes: 6 additions & 5 deletions src/datasets/packaged_modules/text/text.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
from dataclasses import dataclass
from io import StringIO
from typing import Optional
Expand Down Expand Up @@ -42,14 +43,14 @@ def _split_generators(self, dl_manager):
files = data_files
if isinstance(files, str):
files = [files]
return [
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": dl_manager.iter_files(files)})
]
files = [dl_manager.iter_files(file) for file in files]
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
splits = []
for split_name, files in data_files.items():
if isinstance(files, str):
files = [files]
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)}))
files = [dl_manager.iter_files(file) for file in files]
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
return splits

def _cast_table(self, pa_table: pa.Table) -> pa.Table:
Expand All @@ -67,7 +68,7 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:

def _generate_tables(self, files):
pa_table_names = list(self.config.features) if self.config.features is not None else ["text"]
for file_idx, file in enumerate(files):
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
# open in text mode, by default translates universal newlines ("\n", "\r\n" and "\r") into "\n"
with open(file, encoding=self.config.encoding) as f:
if self.config.sample_by == "line":
Expand Down
6 changes: 3 additions & 3 deletions tests/packaged_modules/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def csv_file_with_label(tmp_path):

def test_csv_generate_tables_raises_error_with_malformed_csv(csv_file, malformed_csv_file, caplog):
csv = Csv()
generator = csv._generate_tables([csv_file, malformed_csv_file])
generator = csv._generate_tables([[csv_file, malformed_csv_file]])
with pytest.raises(ValueError, match="Error tokenizing data"):
for _ in generator:
pass
Expand All @@ -89,7 +89,7 @@ def test_csv_cast_image(csv_file_with_image):
with open(csv_file_with_image, encoding="utf-8") as f:
image_file = f.read().splitlines()[1]
csv = Csv(encoding="utf-8", features=Features({"image": Image()}))
generator = csv._generate_tables([csv_file_with_image])
generator = csv._generate_tables([[csv_file_with_image]])
pa_table = pa.concat_tables([table for _, table in generator])
assert pa_table.schema.field("image").type == Image()()
generated_content = pa_table.to_pydict()["image"]
Expand All @@ -101,7 +101,7 @@ def test_csv_cast_label(csv_file_with_label):
with open(csv_file_with_label, encoding="utf-8") as f:
labels = f.read().splitlines()[1:]
csv = Csv(encoding="utf-8", features=Features({"label": ClassLabel(names=["good", "bad"])}))
generator = csv._generate_tables([csv_file_with_label])
generator = csv._generate_tables([[csv_file_with_label]])
pa_table = pa.concat_tables([table for _, table in generator])
assert pa_table.schema.field("label").type == ClassLabel(names=["good", "bad"])()
generated_content = pa_table.to_pydict()["label"]
Expand Down
4 changes: 2 additions & 2 deletions tests/packaged_modules/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_text_linebreaks(text_file, keep_linebreaks):
with open(text_file, encoding="utf-8") as f:
expected_content = f.read().splitlines(keepends=keep_linebreaks)
text = Text(keep_linebreaks=keep_linebreaks, encoding="utf-8")
generator = text._generate_tables([text_file])
generator = text._generate_tables([[text_file]])
generated_content = pa.concat_tables([table for _, table in generator]).to_pydict()["text"]
assert generated_content == expected_content

Expand All @@ -48,7 +48,7 @@ def test_text_cast_image(text_file_with_image):
with open(text_file_with_image, encoding="utf-8") as f:
image_file = f.read().splitlines()[0]
text = Text(encoding="utf-8", features=Features({"image": Image()}))
generator = text._generate_tables([text_file_with_image])
generator = text._generate_tables([[text_file_with_image]])
pa_table = pa.concat_tables([table for _, table in generator])
assert pa_table.schema.field("image").type == Image()()
generated_content = pa_table.to_pydict()["image"]
Expand Down