4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
@@ -28,8 +28,8 @@ jobs:
           pip install .[quality]
       - name: Check quality
         run: |
-          black --check tests src benchmarks metrics
-          ruff tests src benchmarks metrics
+          ruff check tests src benchmarks metrics utils setup.py  # linter
+          ruff format --check tests src benchmarks metrics utils setup.py  # formatter

   test:
     needs: check_code_quality
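
The same quality gate can be reproduced locally with the commands from the workflow step above (a sketch, assuming the repository root as the working directory):

    pip install .[quality]
    ruff check tests src benchmarks metrics utils setup.py           # linter
    ruff format --check tests src benchmarks metrics utils setup.py  # formatter
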
20 changes: 7 additions & 13 deletions .pre-commit-config.yaml
@@ -1,15 +1,9 @@
 repos:
-  - repo: https://github.com/psf/black
-    rev: 23.1.0
+  - repo: https://github.com/charliermarsh/ruff-pre-commit  # https://github.com/charliermarsh/ruff#usage
+    rev: 'v0.1.5'
     hooks:
-      - id: black
-        language_version: python3
-        types: [python]
-        stages: [commit]
-        args: ["--config", "pyproject.toml", "tests", "src", "benchmarks", "metrics"]
-  - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: 'v0.0.255'
-    hooks:
-      - id: ruff
-        stages: [commit]
-        args: [ "--config", "pyproject.toml", "tests", "src", "benchmarks", "metrics", "--fix"]
+      # Run the linter.
+      - id: ruff
+        args: [ --fix ]
+      # Run the formatter.
+      - id: ruff-format
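
With the two black/ruff hooks collapsed into a single ruff-pre-commit entry, the usual pre-commit workflow is unchanged (standard pre-commit CLI, shown for convenience):

    pre-commit install          # register the git hook once per clone
    pre-commit run --all-files  # lint and format the whole tree on demand
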
8 changes: 4 additions & 4 deletions Makefile
@@ -5,14 +5,14 @@ check_dirs := tests src benchmarks metrics utils
 # Check that source code meets quality standards

 quality:
-	black --check $(check_dirs) setup.py
-	ruff $(check_dirs) setup.py
+	ruff check $(check_dirs) setup.py  # linter
+	ruff format --check $(check_dirs) setup.py  # formatter

 # Format source code automatically

 style:
-	black tests src benchmarks metrics setup.py
-	ruff $(check_dirs) setup.py --fix
+	ruff check --fix $(check_dirs) setup.py  # linter
+	ruff format $(check_dirs) setup.py  # formatter

 # Run tests for the library
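
For contributors the Makefile entry points stay the same; only the tooling underneath changes:

    make quality  # read-only checks: ruff check + ruff format --check
    make style    # in-place fixes: ruff check --fix + ruff format
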
2 changes: 1 addition & 1 deletion setup.py
@@ -216,7 +216,7 @@
 TESTS_REQUIRE.extend(VISION_REQUIRE)
 TESTS_REQUIRE.extend(AUDIO_REQUIRE)

-QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241", "pyyaml>=5.3.1"]
+QUALITY_REQUIRE = ["ruff>=0.1.5"]

 DOCS_REQUIRE = [
     # Might need to add doc-builder and some specific deps in the future
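
As a result, the quality extra now pulls in a single tool instead of black, an older ruff, and pyyaml:

    pip install .[quality]  # installs only ruff>=0.1.5
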
3 changes: 2 additions & 1 deletion src/datasets/arrow_dataset.py
@@ -3101,7 +3101,8 @@ def load_processed_shard_from_cache(shard_kwargs):
         else:

             def format_cache_file_name(
-                cache_file_name: Optional[str], rank: Union[int, Literal["*"]]  # noqa: F722
+                cache_file_name: Optional[str],
+                rank: Union[int, Literal["*"]],  # noqa: F722
             ) -> Optional[str]:
                 if not cache_file_name:
                     return cache_file_name
10 changes: 6 additions & 4 deletions src/datasets/iterable_dataset.py
@@ -127,8 +127,9 @@ def _convert_to_arrow(
         Drop the last batch if it is smaller than `batch_size`.
     """
     if batch_size is None or batch_size <= 0:
-        yield "all", pa.Table.from_pylist(
-            cast_to_python_objects([example for _, example in iterable], only_1d_for_numpy=True)
+        yield (
+            "all",
+            pa.Table.from_pylist(cast_to_python_objects([example for _, example in iterable], only_1d_for_numpy=True)),
         )
         return
     iterator = iter(iterable)
@@ -1112,8 +1113,9 @@ def __iter__(self):
             # Then for each example, `TypedExamplesIterable` automatically fills missing columns with None.
             # This is done with `_apply_feature_types_on_example`.
             for key, example in self.ex_iterable:
-                yield key, _apply_feature_types_on_example(
-                    example, self.features, token_per_repo_id=self.token_per_repo_id
+                yield (
+                    key,
+                    _apply_feature_types_on_example(example, self.features, token_per_repo_id=self.token_per_repo_id),
                 )

     def _iter_arrow(self) -> Iterator[Tuple[Key, pa.Table]]:
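
These yield rewrites are behavior-preserving: `yield key, value` and `yield (key, value)` emit the same 2-tuple, and ruff-format simply prefers the explicitly parenthesized, one-element-per-line layout once the expression no longer fits on a single line. A minimal standalone illustration (not from the diff):

    def pairs():
        yield "a", 1    # implicit tuple
        yield ("b", 2)  # parenthesized tuple -- identical at runtime

    assert list(pairs()) == [("a", 1), ("b", 2)]
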
8 changes: 2 additions & 6 deletions src/datasets/load.py
@@ -1493,9 +1493,7 @@ def dataset_module_factory(
                 download_config=download_config,
                 download_mode=download_mode,
             ).get_module()
-        except (
-            Exception
-        ) as e1:  # noqa all the attempts failed, before raising the error we should check if the module is already cached.
+        except Exception as e1:  # noqa all the attempts failed, before raising the error we should check if the module is already cached.
             try:
                 return CachedDatasetModuleFactory(path, dynamic_modules_path=dynamic_modules_path).get_module()
             except Exception:  # noqa if it's not in the cache, then it doesn't exist.
@@ -1598,9 +1596,7 @@ def metric_module_factory(
                 download_mode=download_mode,
                 dynamic_modules_path=dynamic_modules_path,
             ).get_module()
-        except (
-            Exception
-        ) as e1:  # noqa all the attempts failed, before raising the error we should check if the module is already cached.
+        except Exception as e1:  # noqa all the attempts failed, before raising the error we should check if the module is already cached.
             try:
                 return CachedMetricModuleFactory(path, dynamic_modules_path=dynamic_modules_path).get_module()
             except Exception:  # noqa if it's not in the cache, then it doesn't exist.
30 changes: 18 additions & 12 deletions src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py
@@ -323,12 +323,15 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
                             sample_label = {"label": os.path.basename(os.path.dirname(original_file))}
                         else:
                             sample_label = {}
-                        yield file_idx, {
-                            **sample_empty_metadata,
-                            self.BASE_COLUMN_NAME: downloaded_file_or_dir,
-                            **sample_metadata,
-                            **sample_label,
-                        }
+                        yield (
+                            file_idx,
+                            {
+                                **sample_empty_metadata,
+                                self.BASE_COLUMN_NAME: downloaded_file_or_dir,
+                                **sample_metadata,
+                                **sample_label,
+                            },
+                        )
                         file_idx += 1
                 else:
                     for downloaded_dir_file in downloaded_file_or_dir:
@@ -391,10 +394,13 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
                             sample_label = {"label": os.path.basename(os.path.dirname(downloaded_dir_file))}
                         else:
                             sample_label = {}
-                        yield file_idx, {
-                            **sample_empty_metadata,
-                            self.BASE_COLUMN_NAME: downloaded_dir_file,
-                            **sample_metadata,
-                            **sample_label,
-                        }
+                        yield (
+                            file_idx,
+                            {
+                                **sample_empty_metadata,
+                                self.BASE_COLUMN_NAME: downloaded_dir_file,
+                                **sample_metadata,
+                                **sample_label,
+                            },
+                        )
                         file_idx += 1
3 changes: 3 additions & 0 deletions src/datasets/splits.py
@@ -111,6 +111,7 @@ class SplitBase(metaclass=abc.ABCMeta):
     to define which files to read and how to skip examples within file.

     """
+
     # pylint: enable=line-too-long

     @abc.abstractmethod
@@ -265,6 +266,7 @@ class PercentSlice(metaclass=PercentSliceMeta):
     [guide on splits](../loading#slice-splits)
     for more information.
     """
+
     # pylint: enable=line-too-long
     pass

@@ -438,6 +440,7 @@ class Split:
     ...     )
     ```
     """
+
     # pylint: enable=line-too-long
     TRAIN = NamedSplit("train")
     TEST = NamedSplit("test")
2 changes: 1 addition & 1 deletion src/datasets/utils/patching.py
@@ -63,7 +63,7 @@ def __enter__(self):
                 # We don't check for the name of the global, but rather if its value *is* "os" or "os.path".
                 # This allows to patch renamed modules like "from os import path as ospath".
                 if obj_attr is submodule or (
-                    (isinstance(obj_attr, _PatchedModuleObj) and obj_attr._original_module is submodule)
+                    isinstance(obj_attr, _PatchedModuleObj) and obj_attr._original_module is submodule
                 ):
                     self.original[attr] = obj_attr
                     # patch at top level
4 changes: 1 addition & 3 deletions tests/test_arrow_dataset.py
@@ -3066,9 +3066,7 @@ def test_concatenate_mixed_memory_and_disk(self):
                 cache_file_name=os.path.join(tmp_dir, "d1.arrow")
             ) as dset1, Dataset.from_dict(data2, info=info2).map(
                 cache_file_name=os.path.join(tmp_dir, "d2.arrow")
-            ) as dset2, Dataset.from_dict(
-                data3
-            ) as dset3:
+            ) as dset2, Dataset.from_dict(data3) as dset3:
                 with concatenate_datasets([dset1, dset2, dset3]) as concatenated_dset:
                     self.assertEqual(len(concatenated_dset), len(dset1) + len(dset2) + len(dset3))
                     self.assertListEqual(concatenated_dset["id"], dset1["id"] + dset2["id"] + dset3["id"])
1 change: 1 addition & 0 deletions tests/test_readme_util.py
@@ -11,6 +11,7 @@
 # @pytest.fixture
 # def example_yaml_structure():

+
 example_yaml_structure = yaml.safe_load(
     """\
 name: ""