Skip to content

Commit a20194f

Browse files
committed
Unpack dl_manager.iter_files to allow parallelization
1 parent bbe338d commit a20194f

File tree

5 files changed

+20
-20
lines changed

5 files changed

+20
-20
lines changed

src/datasets/packaged_modules/csv/csv.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,14 @@ def _split_generators(self, dl_manager):
138138
files = data_files
139139
if isinstance(files, str):
140140
files = [files]
141-
return [
142-
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": dl_manager.iter_files(files)})
143-
]
141+
files = [dl_manager.iter_files(file) for file in files]
142+
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
144143
splits = []
145144
for split_name, files in data_files.items():
146145
if isinstance(files, str):
147146
files = [files]
148-
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)}))
147+
files = [dl_manager.iter_files(file) for file in files]
148+
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
149149
return splits
150150

151151
def _cast_table(self, pa_table: pa.Table) -> pa.Table:

src/datasets/packaged_modules/json/json.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,14 @@ def _split_generators(self, dl_manager):
5050
files = data_files
5151
if isinstance(files, str):
5252
files = [files]
53-
return [
54-
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": dl_manager.iter_files(files)})
55-
]
53+
files = [dl_manager.iter_files(file) for file in files]
54+
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
5655
splits = []
5756
for split_name, files in data_files.items():
5857
if isinstance(files, str):
5958
files = [files]
60-
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)}))
59+
files = [dl_manager.iter_files(file) for file in files]
60+
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
6161
return splits
6262

6363
def _cast_table(self, pa_table: pa.Table) -> pa.Table:

src/datasets/packaged_modules/pandas/pandas.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ def _split_generators(self, dl_manager):
3131
if isinstance(files, str):
3232
files = [files]
3333
# Use `dl_manager.iter_files` to skip hidden files in an extracted archive
34-
return [
35-
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": dl_manager.iter_files(files)})
36-
]
34+
files = [dl_manager.iter_files(file) for file in files]
35+
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
3736
splits = []
3837
for split_name, files in data_files.items():
3938
if isinstance(files, str):
4039
files = [files]
4140
# Use `dl_manager.iter_files` to skip hidden files in an extracted archive
42-
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)}))
41+
files = [dl_manager.iter_files(file) for file in files]
42+
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
4343
return splits
4444

4545
def _cast_table(self, pa_table: pa.Table) -> pa.Table:

src/datasets/packaged_modules/parquet/parquet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ def _split_generators(self, dl_manager):
3636
if isinstance(files, str):
3737
files = [files]
3838
# Use `dl_manager.iter_files` to skip hidden files in an extracted archive
39-
return [
40-
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": dl_manager.iter_files(files)})
41-
]
39+
files = [dl_manager.iter_files(file) for file in files]
40+
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
4241
splits = []
4342
for split_name, files in data_files.items():
4443
if isinstance(files, str):
4544
files = [files]
4645
# Use `dl_manager.iter_files` to skip hidden files in an extracted archive
47-
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)}))
46+
files = [dl_manager.iter_files(file) for file in files]
47+
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
4848
return splits
4949

5050
def _cast_table(self, pa_table: pa.Table) -> pa.Table:

src/datasets/packaged_modules/text/text.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@ def _split_generators(self, dl_manager):
4242
files = data_files
4343
if isinstance(files, str):
4444
files = [files]
45-
return [
46-
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": dl_manager.iter_files(files)})
47-
]
45+
files = [dl_manager.iter_files(file) for file in files]
46+
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
4847
splits = []
4948
for split_name, files in data_files.items():
5049
if isinstance(files, str):
5150
files = [files]
52-
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)}))
51+
files = [dl_manager.iter_files(file) for file in files]
52+
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
5353
return splits
5454

5555
def _cast_table(self, pa_table: pa.Table) -> pa.Table:

0 commit comments

Comments
 (0)