2 changes: 1 addition & 1 deletion docs/source/add_metric.rst
@@ -31,7 +31,7 @@ The most important attributes to specify are:
- :attr:`datasets.MetricInfo.description`: a :obj:`str` describing the metric,
- :attr:`datasets.MetricInfo.citation`: a :obj:`str` containing the citation for the metric in a BibTex format for inclusion in communications citing the metric,
- :attr:`datasets.MetricInfo.homepage`: a :obj:`str` containing a URL to the original homepage of the metric,
- :attr:`datasets.MetricInfo.format`: an optional :obj:`str` specifying the format of the predictions and the references passed to the :func:`datasets.Metric._compute` method. It can be set to "numpy", "torch", "tensorflow" or "pandas".
- :attr:`datasets.MetricInfo.format`: an optional :obj:`str` specifying the format of the predictions and the references passed to the :func:`datasets.Metric._compute` method. It can be set to "numpy", "torch", "tensorflow", "jax" or "pandas".

Here is, for instance, the :func:`datasets.Metric._info` for the Sacrebleu metric, from the `sacrebleu metric loading script <https://github.com/huggingface/datasets/tree/master/metrics/sacrebleu/sacrebleu.py>`__ (sketched below):
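
A condensed, illustrative sketch along those lines (strings shortened; the names and the optional ``format`` value here are for illustration only, see the linked script for the real definition):

.. code-block:: python

    import datasets

    class Sacrebleu(datasets.Metric):
        def _info(self):
            return datasets.MetricInfo(
                description="SacreBLEU computes shareable, reproducible BLEU scores.",
                citation="@inproceedings{post-2018-call, ...}",
                homepage="https://github.com/mjpost/sacrebleu",
                features=datasets.Features(
                    {
                        "predictions": datasets.Value("string"),
                        "references": datasets.Sequence(datasets.Value("string")),
                    }
                ),
                format="numpy",  # optional: "numpy", "torch", "tensorflow", "jax" or "pandas"
            )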

8 changes: 4 additions & 4 deletions docs/source/exploring.rst
@@ -183,8 +183,8 @@ In particular, you can easily select a specific column in batches, and also natu
True


Working with NumPy, pandas, PyTorch, TensorFlow and on-the-fly formatting transforms
------------------------------------------------------------------------------------
Working with NumPy, pandas, PyTorch, TensorFlow, JAX and on-the-fly formatting transforms
-----------------------------------------------------------------------------------------------------

Up to now, the rows/batches/columns returned when querying the elements of the dataset were python objects.

@@ -198,10 +198,10 @@ A specific format can be activated with :func:`datasets.Dataset.set_format`.

:func:`datasets.Dataset.set_format` accepts the following inputs to control the format of the dataset:

- :obj:`type` (``Union[None, str]``, defaults to ``None``) defines the return type for the dataset :obj:`__getitem__` method and is one of ``[None, 'numpy', 'pandas', 'torch', 'tensorflow']`` (``None`` means return python objects),
- :obj:`type` (``Union[None, str]``, defaults to ``None``) defines the return type for the dataset :obj:`__getitem__` method and is one of ``[None, 'numpy', 'pandas', 'torch', 'tensorflow', 'jax']`` (``None`` means return python objects),
- :obj:`columns` (``Union[None, str, List[str]]``, defaults to ``None``) defines the columns returned by :obj:`__getitem__` and takes the name of a column in the dataset or a list of columns to return (``None`` means return all columns),
- :obj:`output_all_columns` (``bool``, defaults to ``False``) controls whether the columns which cannot be formatted (e.g. a column with ``string`` cannot be cast to a PyTorch tensor) are still output as python objects.
- :obj:`format_kwargs` can be used to provide additional keyword arguments that will be forwarded to the conversion function like ``np.array``, ``torch.tensor`` or ``tensorflow.ragged.constant``. For instance, to create ``torch.Tensor`` directly on the GPU you can specify ``device='cuda'``.
- :obj:`format_kwargs` can be used to provide additional keyword arguments that will be forwarded to the conversion function like ``np.array``, ``torch.tensor``, ``tensorflow.ragged.constant`` or ``jnp.array``. For instance, to create ``torch.Tensor`` directly on the GPU you can specify ``device='cuda'`` (a combined sketch follows this list).
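
For illustration, a combined sketch of these arguments (the tokenized column names are hypothetical):

.. code-block:: python

    >>> dataset.set_format(type='jax',
    ...                    columns=['input_ids', 'attention_mask'],
    ...                    output_all_columns=True)
    >>> dataset[0]['input_ids']   # now a jnp.ndarray; string columns still come back as python objects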

.. note::

2 changes: 1 addition & 1 deletion docs/source/quicktour.rst
@@ -183,7 +183,7 @@ Now that we have encoded our dataset, we want to use it in a ``torch.Dataloader``
To be able to train our model with this dataset and PyTorch, we will need to make three modifications:

- rename our ``label`` column to ``labels``, which is the expected input name for labels in `BertForSequenceClassification <https://huggingface.co/transformers/model_doc/bert.html?#transformers.BertForSequenceClassification.forward>`__ or `TFBertForSequenceClassification <https://huggingface.co/transformers/model_doc/bert.html?#tfbertforsequenceclassification>`__,
- get pytorch (or tensorflow) tensors out of our :class:`datasets.Dataset`, instead of python objects, and
- get pytorch (or tensorflow, or jax) tensors out of our :class:`datasets.Dataset`, instead of python objects, and
- filter the columns to return only the subset of the columns that we need for our model inputs (``input_ids``, ``token_type_ids`` and ``attention_mask``), as in the sketch below.
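
A minimal sketch of these three modifications (assuming a ``datasets`` version where :func:`datasets.Dataset.rename_column` is available):

.. code-block:: python

    >>> dataset = dataset.rename_column('label', 'labels')
    >>> dataset.set_format(type='torch',
    ...                    columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])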

.. note::
3 changes: 2 additions & 1 deletion docs/source/torch_tensorflow.rst
@@ -22,6 +22,7 @@ The format of a :class:`datasets.Dataset` instance can be set using the :func:`datasets.Dataset.set_format` method (see the round-trip sketch after this list):
- ``None``/``'python'`` (default): return python objects,
- ``'torch'``/``'pytorch'``/``'pt'``: return PyTorch tensors,
- ``'tensorflow'``/``'tf'``: return Tensorflow tensors,
- ``'jax'``: return JAX arrays,
- ``'numpy'``/``'np'``: return Numpy arrays,
- ``'pandas'``/``'pd'``: return Pandas DataFrames.
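
A short round-trip sketch of switching formats (assumes a loaded ``dataset`` with a numeric ``input_ids`` column):

.. code-block:: python

    >>> dataset.set_format(type='jax', columns=['input_ids'])
    >>> type(dataset[0]['input_ids'])   # jnp.ndarray
    >>> dataset.reset_format()          # back to python objects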

@@ -80,7 +81,7 @@ Here is how we can apply a format to a simple dataset using :func:`datasets.Dataset.set_format`:

In this example we filtered out the string columns `sentence1` and `sentence2` since they cannot easily be converted to tensors (at least in PyTorch). As detailed above, we could still output them as python objects by setting ``output_all_columns=True``.

We can also pass ``**kwargs`` to the respective conversion functions like ``np.array``, ``torch.tensor`` or ``tensorflow.ragged.constant`` by adding keyword arguments to :func:`datasets.Dataset.set_format()`. For example, if we want the columns formatted as PyTorch CUDA tensors, we use the following:
We can also pass ``**kwargs`` to the respective conversion functions like ``np.array``, ``torch.tensor``, ``tensorflow.ragged.constant`` or ``jnp.array`` by adding keyword arguments to :func:`datasets.Dataset.set_format()`. For example, if we want the columns formatted as PyTorch CUDA tensors, we use the following:

.. code-block::
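
    # sketch (the original snippet is collapsed in this view); assumes a CUDA device is available
    dataset.set_format(type='torch',
                       columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
                       device='cuda')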

16 changes: 16 additions & 0 deletions src/datasets/config.py
@@ -41,6 +41,7 @@

USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_JAX", "AUTO").upper()

TORCH_VERSION = "N/A"
TORCH_AVAILABLE = False
@@ -87,6 +88,21 @@
logger.info("Disabling Tensorflow because USE_TORCH is set")


JAX_VERSION = "N/A"
JAX_AVAILABLE = False

if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
JAX_AVAILABLE = importlib.util.find_spec("jax") is not None
if JAX_AVAILABLE:
try:
JAX_VERSION = importlib_metadata.version("jax")
logger.info(f"JAX version {JAX_VERSION} available.")
except importlib_metadata.PackageNotFoundError:
pass
else:
logger.info("Disabling JAX because USE_JAX is set to False")


USE_BEAM = os.environ.get("USE_BEAM", "AUTO").upper()
BEAM_VERSION = "N/A"
BEAM_AVAILABLE = False
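
A quick sketch of how this environment-variable gating behaves (the value must be set before ``datasets`` is first imported):

.. code-block:: python

    import os

    os.environ["USE_JAX"] = "0"  # anything outside ENV_VARS_TRUE_AND_AUTO_VALUES disables detection
    import datasets.config as config

    print(config.JAX_AVAILABLE)  # False: the `importlib` probe above was skipped entirely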
14 changes: 10 additions & 4 deletions src/datasets/features.py
@@ -17,6 +17,7 @@
""" This class handle features definition in datasets and some utilities to display table type."""
import copy
import re
import sys
from collections.abc import Iterable
from dataclasses import dataclass, field, fields
from typing import Any, ClassVar, Dict, List, Optional
@@ -158,18 +159,23 @@ def _cast_to_python_objects(obj: Any) -> Tuple[Any, bool]:
has_changed (bool): True if the object has been changed, False if it is identical
"""

if config.TF_AVAILABLE:
if config.TF_AVAILABLE and "tensorflow" in sys.modules:
import tensorflow as tf

if config.TORCH_AVAILABLE:
if config.TORCH_AVAILABLE and "torch" in sys.modules:
import torch

if config.JAX_AVAILABLE and "jax" in sys.modules:
import jax.numpy as jnp

if isinstance(obj, np.ndarray):
return obj.tolist(), True
elif config.TORCH_AVAILABLE and isinstance(obj, torch.Tensor):
elif config.TORCH_AVAILABLE and "torch" in sys.modules and isinstance(obj, torch.Tensor):
return obj.detach().cpu().numpy().tolist(), True
elif config.TF_AVAILABLE and isinstance(obj, tf.Tensor):
elif config.TF_AVAILABLE and "tensorflow" in sys.modules and isinstance(obj, tf.Tensor):
return obj.numpy().tolist(), True
elif config.JAX_AVAILABLE and "jax" in sys.modules and isinstance(obj, jnp.ndarray):
return obj.tolist(), True
elif isinstance(obj, pd.Series):
return obj.values.tolist(), True
elif isinstance(obj, pd.DataFrame):
8 changes: 8 additions & 0 deletions src/datasets/formatting/__init__.py
@@ -93,6 +93,14 @@ def _register_unavailable_formatter(
_tf_error = ValueError("Tensorflow needs to be installed to be able to return Tensorflow tensors.")
_register_unavailable_formatter(_tf_error, "tensorflow", aliases=["tf"])

if config.JAX_AVAILABLE:
from .jax_formatter import JaxFormatter

_register_formatter(JaxFormatter, "jax", aliases=[])
else:
_jax_error = ValueError("JAX needs to be installed to be able to return JAX arrays.")
_register_unavailable_formatter(_jax_error, "jax", aliases=[])


def get_format_type_from_alias(format_type: Optional[str]) -> Optional[str]:
"""If the given format type is a known alias, then return its main type name. Otherwise return the type with no change."""
75 changes: 75 additions & 0 deletions src/datasets/formatting/jax_formatter.py
@@ -0,0 +1,75 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
from typing import TYPE_CHECKING

import numpy as np
import pyarrow as pa

from ..utils.py_utils import map_nested
from .formatting import Formatter


if TYPE_CHECKING:
import jax.numpy as jnp


class JaxFormatter(Formatter[dict, "jnp.ndarray", dict]):
def __init__(self, **jnp_array_kwargs):
self.jnp_array_kwargs = jnp_array_kwargs
        import jax.numpy as jnp  # noqa: F401 - imported at init time so a missing jax fails fast

def _tensorize(self, value):
import jax
import jax.numpy as jnp

default_dtype = {}
if np.issubdtype(value.dtype, np.integer):
# the default int precision depends on the jax config
# see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
if jax.config.jax_enable_x64:
default_dtype = {"dtype": jnp.int64}
else:
default_dtype = {"dtype": jnp.int32}
elif np.issubdtype(value.dtype, np.floating):
default_dtype = {"dtype": jnp.float32}

# calling jnp.array on a np.ndarray does copy the data
# see https://github.com/google/jax/issues/4486
return jnp.array(value, **{**default_dtype, **self.jnp_array_kwargs})

def _recursive_tensorize(self, data_struct: dict):
# support for nested types like struct of list of struct
if isinstance(data_struct, (list, np.ndarray)):
data_struct = np.array(data_struct, copy=False)
            if data_struct.dtype == object:  # jax arrays cannot be instantiated from an array of objects
return [self.recursive_tensorize(substruct) for substruct in data_struct]
return self._tensorize(data_struct)

def recursive_tensorize(self, data_struct: dict):
return map_nested(self._recursive_tensorize, data_struct, map_list=False)

def format_row(self, pa_table: pa.Table) -> dict:
row = self.numpy_arrow_extractor().extract_row(pa_table)
return self.recursive_tensorize(row)

def format_column(self, pa_table: pa.Table) -> "jnp.ndarray":
col = self.numpy_arrow_extractor().extract_column(pa_table)
return self.recursive_tensorize(col)

def format_batch(self, pa_table: pa.Table) -> dict:
batch = self.numpy_arrow_extractor().extract_batch(pa_table)
return self.recursive_tensorize(batch)
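
A small sketch of the dtype defaulting above (assumes JAX is installed and x64 mode is left disabled, its default):

.. code-block:: python

    import numpy as np

    from datasets.formatting.jax_formatter import JaxFormatter

    formatter = JaxFormatter()
    formatter._tensorize(np.array([1, 2])).dtype   # int32, unless jax_enable_x64 is set
    formatter._tensorize(np.array([1.5])).dtype    # float32, although numpy defaults to float64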
2 changes: 2 additions & 0 deletions src/datasets/utils/file_utils.py
@@ -375,6 +375,8 @@ def get_datasets_user_agent(user_agent: Optional[Union[str, dict]] = None) -> str:
ua += "; torch/{}".format(config.TORCH_VERSION)
if config.TF_AVAILABLE:
ua += "; tensorflow/{}".format(config.TF_VERSION)
if config.JAX_AVAILABLE:
ua += "; jax/{}".format(config.JAX_VERSION)
if config.BEAM_AVAILABLE:
ua += "; apache_beam/{}".format(config.BEAM_VERSION)
if isinstance(user_agent, dict):
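
The resulting header then looks roughly like this (version numbers hypothetical):

.. code-block:: python

    >>> from datasets.utils.file_utils import get_datasets_user_agent
    >>> get_datasets_user_agent()
    'datasets/1.5.0; python/3.8.5; torch/1.8.0; jax/0.2.10'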
14 changes: 13 additions & 1 deletion tests/test_features.py
@@ -20,7 +20,7 @@
)
from datasets.info import DatasetInfo

from .utils import require_tf, require_torch
from .utils import require_jax, require_tf, require_torch


class FeaturesTest(TestCase):
@@ -294,6 +294,18 @@ def test_cast_to_python_objects_tf(self):
casted_obj = cast_to_python_objects(obj)
self.assertDictEqual(casted_obj, expected_obj)

@require_jax
def test_cast_to_python_objects_jax(self):
import jax.numpy as jnp

obj = {
"col_1": [{"vec": jnp.array(np.arange(1, 4)), "txt": "foo"}] * 3,
"col_2": jnp.array(np.arange(1, 7).reshape(3, 2)),
}
expected_obj = {"col_1": [{"vec": [1, 2, 3], "txt": "foo"}] * 3, "col_2": [[1, 2], [3, 4], [5, 6]]}
casted_obj = cast_to_python_objects(obj)
self.assertDictEqual(casted_obj, expected_obj)

@patch("datasets.features._cast_to_python_objects", side_effect=_cast_to_python_objects)
def test_dont_iterate_over_each_element_in_a_list(self, mocked_cast):
obj = {"col_1": [[1, 2], [3, 4], [5, 6]]}
35 changes: 34 additions & 1 deletion tests/test_formatting.py
@@ -9,7 +9,7 @@
from datasets.formatting.formatting import NumpyArrowExtractor, PandasArrowExtractor, PythonArrowExtractor
from datasets.table import InMemoryTable

from .utils import require_tf, require_torch
from .utils import require_jax, require_tf, require_torch


_COL_A = [0, 1, 2]
@@ -195,6 +195,39 @@ def test_tf_formatter_np_array_kwargs(self):
self.assertEqual(batch["a"].dtype, tf.float16)
self.assertEqual(batch["c"].dtype, tf.float16)

@require_jax
def test_jax_formatter(self):
import jax.numpy as jnp

from datasets.formatting import JaxFormatter

pa_table = self._create_dummy_table().drop(["b"])
formatter = JaxFormatter()
row = formatter.format_row(pa_table)
jnp.allclose(row["a"], jnp.array(_COL_A, dtype=jnp.int64)[0])
jnp.allclose(row["c"], jnp.array(_COL_C, dtype=jnp.float32)[0])
col = formatter.format_column(pa_table)
jnp.allclose(col, jnp.array(_COL_A, dtype=jnp.int64))
batch = formatter.format_batch(pa_table)
jnp.allclose(batch["a"], jnp.array(_COL_A, dtype=jnp.int64))
jnp.allclose(batch["c"], jnp.array(_COL_C, dtype=jnp.float32))

@require_jax
def test_jax_formatter_np_array_kwargs(self):
import jax.numpy as jnp

from datasets.formatting import JaxFormatter

pa_table = self._create_dummy_table().drop(["b"])
formatter = JaxFormatter(dtype=jnp.float16)
row = formatter.format_row(pa_table)
self.assertEqual(row["c"].dtype, jnp.float16)
col = formatter.format_column(pa_table)
self.assertEqual(col.dtype, jnp.float16)
batch = formatter.format_batch(pa_table)
self.assertEqual(batch["a"].dtype, jnp.float16)
self.assertEqual(batch["c"].dtype, jnp.float16)


class QueryTest(TestCase):
def _create_dummy_table(self):
12 changes: 12 additions & 0 deletions tests/utils.py
@@ -112,6 +112,18 @@ def require_tf(test_case):
return test_case


def require_jax(test_case):
"""
Decorator marking a test that requires JAX.

These tests are skipped when JAX isn't installed.

"""
if not config.JAX_AVAILABLE:
test_case = unittest.skip("test requires JAX")(test_case)
return test_case


def require_transformers(test_case):
"""
Decorator marking a test that requires transformers.