5 changes: 3 additions & 2 deletions benchmarks/utils.py
@@ -3,7 +3,8 @@
import numpy as np

import datasets
from datasets.features import _ArrayXD
from datasets.arrow_writer import ArrowWriter
from datasets.features.features import _ArrayXD


def get_duration(func):
@@ -46,7 +47,7 @@ def generate_examples(features: dict, num_examples=100, seq_shapes=None):
def generate_example_dataset(dataset_path, features, num_examples=100, seq_shapes=None):
dummy_data = generate_examples(features, num_examples=num_examples, seq_shapes=seq_shapes)

with datasets.ArrowWriter(features=features, path=dataset_path) as writer:
with ArrowWriter(features=features, path=dataset_path) as writer:
for key, record in dummy_data:
example = features.encode_example(record)
writer.write(example)
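For reference, a minimal runnable sketch of the updated benchmark pattern, with `ArrowWriter` imported from its submodule rather than the top-level `datasets` namespace; the output path and feature schema below are illustrative stand-ins, not taken from the benchmark itself.

```python
import datasets
from datasets.arrow_writer import ArrowWriter

# Illustrative schema and dummy records standing in for generate_examples().
features = datasets.Features({"text": datasets.Value("string"), "label": datasets.Value("int32")})
dummy_data = [(i, {"text": f"example {i}", "label": i % 2}) for i in range(100)]

with ArrowWriter(features=features, path="dummy_dataset.arrow") as writer:
    for key, record in dummy_data:
        example = features.encode_example(record)
        writer.write(example)
    writer.finalize()
```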
4 changes: 2 additions & 2 deletions metrics/perplexity/perplexity.py
@@ -17,7 +17,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

import datasets
from datasets.utils import tqdm
from datasets import utils


_CITATION = """\
@@ -113,7 +113,7 @@ def _compute(self, input_texts, model_id, stride=512, device=None):

ppls = []

for text_index in tqdm(range(0, len(encoded_texts))):
for text_index in utils.tqdm_utils.tqdm(range(0, len(encoded_texts))):
encoded_text = encoded_texts[text_index]
special_tokens_mask = special_tokens_masks[text_index]
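A hedged sketch of the loop above in isolation: the metric now reaches tqdm through the `utils` package instead of importing it from `datasets.utils` directly. Assumed here: `utils.tqdm_utils.tqdm` accepts the same arguments as standard tqdm, as the call above suggests; the token ids are made up.

```python
from datasets import utils

# Made-up encoded inputs standing in for the tokenizer output.
encoded_texts = [[101, 2023, 102], [101, 2003, 102]]

# tqdm is reached through the utils package so it honors the library's
# progress-bar settings (assumed to mirror standard tqdm's interface).
for text_index in utils.tqdm_utils.tqdm(range(0, len(encoded_texts))):
    encoded_text = encoded_texts[text_index]
    # ... per-text perplexity computation continues as in the metric above ...
```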

35 changes: 11 additions & 24 deletions src/datasets/__init__.py
@@ -20,36 +20,26 @@
__version__ = "1.18.5.dev0"

import pyarrow
from packaging import version as _version
from pyarrow import total_allocated_bytes
from packaging import version


if _version.parse(pyarrow.__version__).major < 5:
if version.parse(pyarrow.__version__).major < 5:
raise ImportWarning(
"To use `datasets`, the module `pyarrow>=5.0.0` is required, and the current version of `pyarrow` doesn't match this condition.\n"
"If you are running this in a Google Colab, you should probably just restart the runtime to use the right version of `pyarrow`."
)

SCRIPTS_VERSION = "master" if version.parse(__version__).is_devrelease else __version__

del pyarrow
del version

from .arrow_dataset import Dataset, concatenate_datasets
from .arrow_reader import ArrowReader, ReadInstruction
from .arrow_writer import ArrowWriter
Comment on lines -34 to -35

Member:
Here ArrowWriter and ArrowReader are removed from the top level module. I've seen that the lxmert example in transformers needs datasets.ArrowWriter, so I'm not sure if we should remove them.

Collaborator (Author):
They are removed because I see them as our internals. I'll open a PR in Transformers to update their paths (used here and here).

Member:
Ok :)
from .arrow_reader import ReadInstruction
from .builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
from .combine import interleave_datasets
from .dataset_dict import DatasetDict, IterableDatasetDict
from .features import (
Array2D,
Array3D,
Array4D,
Array5D,
Audio,
ClassLabel,
Features,
Image,
Sequence,
Translation,
TranslationVariableLanguages,
Value,
)
from .features import *
from .fingerprint import is_caching_enabled, set_caching_enabled
from .info import DatasetInfo, MetricInfo
from .inspect import (
@@ -63,8 +53,7 @@
list_metrics,
)
from .iterable_dataset import IterableDataset
from .keyhash import KeyHasher
from .load import import_main_class, load_dataset, load_dataset_builder, load_from_disk, load_metric
from .load import load_dataset, load_dataset_builder, load_from_disk, load_metric
from .metric import Metric
from .splits import (
NamedSplit,
@@ -79,6 +68,4 @@
)
from .tasks import *
from .utils import *


SCRIPTS_VERSION = "master" if _version.parse(__version__).is_devrelease else __version__
from .utils import logging
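As noted in the review thread above, `ArrowWriter` and `ArrowReader` (along with other internals such as `KeyHasher` and `import_main_class`) no longer live in the top-level namespace. A hedged before/after sketch of the import update downstream code would need; the exact call sites in transformers are not shown here.

```python
# Before this PR (top-level aliases):
#   from datasets import ArrowWriter, ArrowReader, KeyHasher
#   from datasets import import_main_class

# After this PR, import the internals from their submodules instead:
from datasets.arrow_writer import ArrowWriter
from datasets.arrow_reader import ArrowReader
from datasets.keyhash import KeyHasher
from datasets.load import import_main_class
```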
21 changes: 5 additions & 16 deletions src/datasets/arrow_dataset.py
@@ -58,19 +58,8 @@
from . import config, utils
from .arrow_reader import ArrowReader
from .arrow_writer import ArrowWriter, OptimizedTypedSequence
from .features import (
Audio,
ClassLabel,
Features,
FeatureType,
Image,
Sequence,
Value,
_ArrayXD,
decode_nested_example,
pandas_types_mapper,
require_decoding,
)
from .features import Audio, ClassLabel, Features, Image, Sequence, Value
from .features.features import FeatureType, _ArrayXD, decode_nested_example, pandas_types_mapper, require_decoding
from .filesystems import extract_path_from_uri, is_remote_filesystem
from .fingerprint import (
fingerprint_transform,
@@ -2316,7 +2305,7 @@ def init_buffer_and_writer():
pbar_total = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
pbar_unit = "ex" if not batched else "ba"
pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank) if rank is not None else desc
pbar = utils.tqdm(
pbar = utils.tqdm_utils.tqdm(
pbar_iterable,
total=pbar_total,
disable=disable_tqdm,
@@ -3465,7 +3454,7 @@ def delete_file(file):
api.delete_file(file, repo_id=repo_id, token=token, repo_type="dataset", revision=branch)

if len(file_shards_to_delete):
for file in utils.tqdm(
for file in utils.tqdm_utils.tqdm(
file_shards_to_delete,
desc="Deleting unused files from dataset repository",
total=len(file_shards_to_delete),
@@ -3474,7 +3463,7 @@ def delete_file(file):
delete_file(file)

uploaded_size = 0
for index, shard in utils.tqdm(
for index, shard in utils.tqdm_utils.tqdm(
enumerate(shards),
desc="Pushing dataset shards to the dataset hub",
total=num_shards,
5 changes: 2 additions & 3 deletions src/datasets/arrow_reader.py
@@ -26,11 +26,10 @@
import pyarrow as pa
import pyarrow.parquet as pq

from datasets.utils.file_utils import DownloadConfig

from .naming import _split_re, filename_for_dataset_split
from .table import InMemoryTable, MemoryMappedTable, Table, concat_tables
from .utils import cached_path, logging
from .utils import logging
from .utils.file_utils import DownloadConfig, cached_path


if TYPE_CHECKING:
12 changes: 5 additions & 7 deletions src/datasets/arrow_writer.py
@@ -22,12 +22,10 @@
import numpy as np
import pyarrow as pa

from datasets.features.features import FeatureType, Value

from . import config, utils
from .features import (
Features,
Image,
from .features import Features, Image, Value
from .features.features import (
FeatureType,
_ArrayXDExtensionType,
cast_to_python_objects,
generate_from_arrow_type,
@@ -641,9 +639,9 @@ def parquet_to_arrow(sources, destination):
stream = None if isinstance(destination, str) else destination
disable = not utils.is_progress_bar_enabled()
with ArrowWriter(path=destination, stream=stream) as writer:
for source in utils.tqdm(sources, unit="sources", disable=disable):
for source in utils.tqdm_utils.tqdm(sources, unit="sources", disable=disable):
pf = pa.parquet.ParquetFile(source)
for i in utils.tqdm(range(pf.num_row_groups), unit="row_groups", leave=False, disable=disable):
for i in utils.tqdm_utils.tqdm(range(pf.num_row_groups), unit="row_groups", leave=False, disable=disable):
df = pf.read_row_group(i).to_pandas()
for col in df.columns:
df[col] = df[col].apply(json.loads)
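For clarity, a standalone sketch of the row-group iteration that `parquet_to_arrow` performs above, using only public pyarrow APIs; the progress bars and the per-column `json.loads` decoding of the original helper are omitted.

```python
import pyarrow.parquet as pq

def iter_row_groups_as_pandas(source):
    """Yield each row group of a Parquet file as a pandas DataFrame."""
    pf = pq.ParquetFile(source)
    for i in range(pf.num_row_groups):
        # Row groups are read one at a time to keep memory bounded.
        yield pf.read_row_group(i).to_pandas()
```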
41 changes: 24 additions & 17 deletions src/datasets/builder.py
@@ -27,9 +27,6 @@
from functools import partial
from typing import Dict, Mapping, Optional, Tuple, Union

from datasets.features import Features
from datasets.utils.mock_download_manager import MockDownloadManager

from . import config, utils
from .arrow_dataset import Dataset
from .arrow_reader import (
@@ -42,16 +39,26 @@
from .arrow_writer import ArrowWriter, BeamWriter
from .data_files import DataFilesDict, sanitize_patterns
from .dataset_dict import DatasetDict, IterableDatasetDict
from .features import Features
from .fingerprint import Hasher
from .info import DatasetInfo, DatasetInfosDict, PostProcessedInfo
from .iterable_dataset import ExamplesIterable, IterableDataset, _generate_examples_from_tables_wrapper
from .naming import camelcase_to_snakecase, filename_prefix_for_split
from .splits import Split, SplitDict, SplitGenerator
from .utils import logging
from .utils.download_manager import DownloadManager, DownloadMode
from .utils.file_utils import DownloadConfig, is_remote_url
from .utils.file_utils import DownloadConfig, cached_path, is_remote_url
from .utils.filelock import FileLock
from .utils.info_utils import get_size_checksum_dict, verify_checksums, verify_splits
from .utils.mock_download_manager import MockDownloadManager
from .utils.py_utils import (
classproperty,
has_sufficient_disk_space,
map_nested,
memoize,
size_str,
temporary_assignment,
)
from .utils.streaming_download_manager import StreamingDownloadManager


@@ -389,9 +396,9 @@ def _create_builder_config(self, name=None, custom_features=None, **config_kwarg

return builder_config, config_id

@utils.classproperty
@classproperty
@classmethod
@utils.memoize()
@memoize()
def builder_configs(cls):
"""Pre-defined list of configurations for this builder class."""
configs = {config.name: config for config in cls.BUILDER_CONFIGS}
@@ -537,9 +544,9 @@ def download_and_prepare(
return
logger.info(f"Generating dataset {self.name} ({self._cache_dir})")
if not is_remote_url(self._cache_dir_root): # if cache dir is local, check for available space
if not utils.has_sufficient_disk_space(self.info.size_in_bytes or 0, directory=self._cache_dir_root):
if not has_sufficient_disk_space(self.info.size_in_bytes or 0, directory=self._cache_dir_root):
raise OSError(
f"Not enough disk space. Needed: {utils.size_str(self.info.size_in_bytes or 0)} (download: {utils.size_str(self.info.download_size or 0)}, generated: {utils.size_str(self.info.dataset_size or 0)}, post-processed: {utils.size_str(self.info.post_processing_size or 0)})"
f"Not enough disk space. Needed: {size_str(self.info.size_in_bytes or 0)} (download: {size_str(self.info.download_size or 0)}, generated: {size_str(self.info.dataset_size or 0)}, post-processed: {size_str(self.info.post_processing_size or 0)})"
)

@contextlib.contextmanager
@@ -565,9 +572,9 @@ def incomplete_dir(dirname):
if self.info.size_in_bytes:
print(
f"Downloading and preparing dataset {self.info.builder_name}/{self.info.config_name} "
f"(download: {utils.size_str(self.info.download_size)}, generated: {utils.size_str(self.info.dataset_size)}, "
f"post-processed: {utils.size_str(self.info.post_processing_size)}, "
f"total: {utils.size_str(self.info.size_in_bytes)}) to {self._cache_dir}..."
f"(download: {size_str(self.info.download_size)}, generated: {size_str(self.info.dataset_size)}, "
f"post-processed: {size_str(self.info.post_processing_size)}, "
f"total: {size_str(self.info.size_in_bytes)}) to {self._cache_dir}..."
)
else:
print(
@@ -580,7 +587,7 @@ def incomplete_dir(dirname):
with incomplete_dir(self._cache_dir) as tmp_data_dir:
# Temporarily assign _cache_dir to tmp_data_dir to avoid having to forward
# it to every sub function.
with utils.temporary_assignment(self, "_cache_dir", tmp_data_dir):
with temporary_assignment(self, "_cache_dir", tmp_data_dir):
# Try to download the already prepared dataset files
downloaded_from_gcs = False
if try_from_hf_gcs:
@@ -637,7 +644,7 @@ def _download_prepared_from_hf_gcs(self, download_config: DownloadConfig):
if os.sep in resource_file_name:
raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}")
try:
resource_path = utils.cached_path(remote_cache_dir + "/" + resource_file_name)
resource_path = cached_path(remote_cache_dir + "/" + resource_file_name)
shutil.move(resource_path, os.path.join(self._cache_dir, resource_file_name))
except ConnectionError:
logger.info(f"Couldn't download resourse file {resource_file_name} from Hf google storage.")
@@ -761,7 +768,7 @@ def as_dataset(
split = {s: s for s in self.info.splits}

# Create a dataset for each of the given splits
datasets = utils.map_nested(
datasets = map_nested(
partial(
self._build_single_dataset,
run_post_process=run_post_process,
@@ -903,7 +910,7 @@ def as_streaming_dataset(
raise ValueError(f"Bad split: {split}. Available splits: {list(splits_generators)}")

# Create a dataset for each of the given splits
datasets = utils.map_nested(
datasets = map_nested(
self._as_streaming_dataset_single,
splits_generator,
map_tuple=True,
@@ -1074,7 +1081,7 @@ def _prepare_split(self, split_generator):
check_duplicates=True,
) as writer:
try:
for key, record in utils.tqdm(
for key, record in utils.tqdm_utils.tqdm(
generator,
unit=" examples",
total=split_info.num_examples,
@@ -1135,7 +1142,7 @@ def _prepare_split(self, split_generator):

generator = self._generate_tables(**split_generator.gen_kwargs)
with ArrowWriter(features=self.info.features, path=fpath) as writer:
for key, table in utils.tqdm(
for key, table in utils.tqdm_utils.tqdm(
generator, unit=" tables", leave=False, disable=True # not utils.is_progress_bar_enabled()
):
writer.write_table(table)
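The builder now imports its generic helpers explicitly from `datasets.utils.py_utils` instead of reaching through the `utils` namespace. A small hedged sketch of two of them; exact output formatting may differ between library versions.

```python
from datasets.utils.py_utils import map_nested, size_str

# map_nested applies a function to every leaf value of a (possibly nested)
# structure, as used for building the per-split datasets above.
print(map_nested(len, {"train": "abcde", "test": "xy"}))  # e.g. {'train': 5, 'test': 2}

# size_str renders a byte count in human-readable form, as used in the
# disk-space message above; the exact unit string depends on the version.
print(size_str(1_500_000))
```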
2 changes: 1 addition & 1 deletion src/datasets/commands/dummy_data.py
@@ -11,10 +11,10 @@
from datasets import config
from datasets.commands import BaseDatasetsCLICommand
from datasets.load import dataset_module_factory, import_main_class
from datasets.utils import MockDownloadManager
from datasets.utils.download_manager import DownloadManager
from datasets.utils.file_utils import DownloadConfig
from datasets.utils.logging import get_logger, set_verbosity_warning
from datasets.utils.mock_download_manager import MockDownloadManager
from datasets.utils.py_utils import map_nested


3 changes: 1 addition & 2 deletions src/datasets/data_files.py
@@ -7,8 +7,7 @@
from fsspec.implementations.local import LocalFileSystem
from tqdm.contrib.concurrent import thread_map

from datasets.filesystems.hffilesystem import HfFileSystem

from .filesystems.hffilesystem import HfFileSystem
from .splits import Split
from .utils import logging
from .utils.file_utils import hf_hub_url, is_remote_url, request_etag
26 changes: 17 additions & 9 deletions src/datasets/features/__init__.py
@@ -1,12 +1,20 @@
# flake8: noqa

__all__ = [
"Audio",
"Array2D",
"Array3D",
"Array4D",
"Array5D",
"ClassLabel",
"Features",
"Sequence",
"Value",
"Image",
"Translation",
"TranslationVariableLanguages",
]
from .audio import Audio
from .features import *
from .features import (
_ArrayXD,
_ArrayXDExtensionType,
_arrow_to_datasets_dtype,
_cast_to_python_objects,
_is_zero_copy_only,
)
from .image import Image, objects_to_list_of_image_dicts
from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, Sequence, Value
from .image import Image
from .translation import Translation, TranslationVariableLanguages
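A short usage sketch of the feature classes the new `__all__` above re-exports; the column names and shapes are made up.

```python
from datasets.features import Array2D, ClassLabel, Features, Sequence, Value

# Illustrative schema exercising a few of the publicly exported feature types.
features = Features(
    {
        "tokens": Sequence(Value("string")),
        "label": ClassLabel(names=["negative", "positive"]),
        "scores": Array2D(shape=(4, 4), dtype="float32"),
    }
)
print(features)
```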