diff --git a/docs/source/add_metric.rst b/docs/source/add_metric.rst
index 33d05104e25..0b4fa7d9aaa 100644
--- a/docs/source/add_metric.rst
+++ b/docs/source/add_metric.rst
@@ -31,7 +31,7 @@ The most important attributes to specify are:
 - :attr:`datasets.MetricInfo.description`: a :obj:`str` describing the metric,
 - :attr:`datasets.MetricInfo.citation`: a :obj:`str` containing the citation for the metric in a BibTex format for inclusion in communications citing the metric,
 - :attr:`datasets.MetricInfo.homepage`: a :obj:`str` containing an URL to an original homepage of the metric.
-- :attr:`datasets.MetricInfo.format`: an optional :obj:`str` to tell what is the format of the predictions and the references passed to the :func:`datasets.DatasetBuilder._compute` method. It can be set to "numpy", "torch", "tensorflow" or "pandas".
+- :attr:`datasets.MetricInfo.format`: an optional :obj:`str` to tell what is the format of the predictions and the references passed to the :func:`datasets.DatasetBuilder._compute` method. It can be set to "numpy", "torch", "tensorflow", "jax" or "pandas".
 
 Here is for instance the :func:`datasets.Metric._info` for the Sacrebleu metric, which is taken from the `sacrebleu metric loading script `__:
diff --git a/docs/source/exploring.rst b/docs/source/exploring.rst
index 61c75437fa0..736fd04aa20 100644
--- a/docs/source/exploring.rst
+++ b/docs/source/exploring.rst
@@ -183,8 +183,8 @@ In particular, you can easily select a specific column in batches, and also natu
     True
 
-Working with NumPy, pandas, PyTorch, TensorFlow and on-the-fly formatting transforms
-------------------------------------------------------------------------------------
+Working with NumPy, pandas, PyTorch, TensorFlow, JAX and on-the-fly formatting transforms
+------------------------------------------------------------------------------------------
 
 Up to now, the rows/batches/columns returned when querying the elements of the dataset were python objects.
@@ -198,10 +198,10 @@ A specific format can be activated with :func:`datasets.Dataset.set_format`.
 
 :func:`datasets.Dataset.set_format` accepts those inputs to control the format of the dataset:
 
-- :obj:`type` (``Union[None, str]``, default to ``None``) defines the return type for the dataset :obj:`__getitem__` method and is one of ``[None, 'numpy', 'pandas', 'torch', 'tensorflow']`` (``None`` means return python objects),
+- :obj:`type` (``Union[None, str]``, default to ``None``) defines the return type for the dataset :obj:`__getitem__` method and is one of ``[None, 'numpy', 'pandas', 'torch', 'tensorflow', 'jax']`` (``None`` means return python objects),
 - :obj:`columns` (``Union[None, str, List[str]]``, default to ``None``) defines the columns returned by :obj:`__getitem__` and takes the name of a column in the dataset or a list of columns to return (``None`` means return all columns),
 - :obj:`output_all_columns` (``bool``, default to ``False``) controls whether the columns which cannot be formatted (e.g. a column with ``string`` cannot be cast in a PyTorch Tensor) are still outputted as python objects.
-- :obj:`format_kwargs` can be used to provide additional keywords arguments that will be forwarded to the convertiong function like ``np.array``, ``torch.tensor`` or ``tensorflow.ragged.constant``. For instance, to create ``torch.Tensor`` directly on the GPU you can specify ``device='cuda'``.
+- :obj:`format_kwargs` can be used to provide additional keyword arguments that will be forwarded to the converting function like ``np.array``, ``torch.tensor``, ``tensorflow.ragged.constant`` or ``jnp.array``. For instance, to create ``torch.Tensor`` directly on the GPU you can specify ``device='cuda'``.
 
 .. note::
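For context on the docs change above, here is a minimal sketch of what the new ``'jax'`` format does end to end; the dataset and column names are made up for illustration, and it assumes ``datasets`` with this patch and ``jax`` are both installed:

.. code-block:: python

    from datasets import Dataset

    # A tiny in-memory dataset; "label" and "value" are hypothetical columns.
    ds = Dataset.from_dict({"label": [0, 1, 0], "value": [0.1, 0.2, 0.3]})

    # Return JAX arrays instead of python objects for the selected columns.
    ds.set_format(type="jax", columns=["label", "value"])

    row = ds[0]     # {'label': jnp scalar, 'value': jnp scalar}
    batch = ds[:2]  # each value is a jnp.ndarray of length 2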
diff --git a/docs/source/quicktour.rst b/docs/source/quicktour.rst
index bd7e8d53687..664e769ac18 100644
--- a/docs/source/quicktour.rst
+++ b/docs/source/quicktour.rst
@@ -183,7 +183,7 @@ Now that we have encoded our dataset, we want to use it in a ``torch.Dataloader``
 To be able to train our model with this dataset and PyTorch, we will need to do three modifications:
 
 - rename our ``label`` column in ``labels`` which is the expected input name for labels in `BertForSequenceClassification `__ or `TFBertForSequenceClassification `__,
-- get pytorch (or tensorflow) tensors out of our :class:`datasets.Dataset`, instead of python objects, and
+- get pytorch (or tensorflow or jax) tensors out of our :class:`datasets.Dataset`, instead of python objects, and
 - filter the columns to return only the subset of the columns that we need for our model inputs (``input_ids``, ``token_type_ids`` and ``attention_mask``).
 
 .. note::
diff --git a/docs/source/torch_tensorflow.rst b/docs/source/torch_tensorflow.rst
index e2df1d553f7..abb24900646 100644
--- a/docs/source/torch_tensorflow.rst
+++ b/docs/source/torch_tensorflow.rst
@@ -22,6 +22,7 @@ The format of a :class:`datasets.Dataset` instance can be set using the :func:`d
   - ``None``/``'python'`` (default): return python objects,
   - ``'torch'``/``'pytorch'``/``'pt'``: return PyTorch tensors,
   - ``'tensorflow'``/``'tf'``: return Tensorflow tensors,
+  - ``'jax'``: return JAX arrays,
   - ``'numpy'``/``'np'``: return Numpy arrays,
   - ``'pandas'``/``'pd'``: return Pandas DataFrames.
@@ -80,7 +81,7 @@ Here is how we can apply a format to a simple dataset using :func:`datasets.Data
 
 In this examples we filtered out the string columns `sentence1` and `sentence2` since they cannot be converted easily as tensors (at least in PyTorch). As detailed above, we could still output them as python object by setting ``output_all_columns=True``.
 
-We can also pass ``**kwargs`` to the respective convert functions like ``np.array``, ``torch.tensor`` or ``tensorflow.ragged.constant`` by adding keyword arguments to :func:`datasets.Dataset.set_format()`. For example, if we want the columns formatted as PyTorch CUDA tensors, we use the following:
+We can also pass ``**kwargs`` to the respective convert functions like ``np.array``, ``torch.tensor``, ``tensorflow.ragged.constant`` or ``jnp.array`` by adding keyword arguments to :func:`datasets.Dataset.set_format()`. For example, if we want the columns formatted as PyTorch CUDA tensors, we use the following:
 
 .. code-block::
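Picking up the ``format_kwargs`` behavior documented above, a small sketch (again on a made-up dataset) of forwarding a dtype to ``jnp.array``:

.. code-block:: python

    import jax.numpy as jnp
    from datasets import Dataset

    ds = Dataset.from_dict({"value": [0.1, 0.2, 0.3]})

    # Extra keyword arguments are forwarded to jnp.array by the jax formatter.
    ds.set_format(type="jax", columns=["value"], dtype=jnp.float16)
    assert ds["value"].dtype == jnp.float16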
diff --git a/src/datasets/config.py b/src/datasets/config.py
index 9f5ac95a491..d1e6259259e 100644
--- a/src/datasets/config.py
+++ b/src/datasets/config.py
@@ -41,6 +41,7 @@
 USE_TF = os.environ.get("USE_TF", "AUTO").upper()
 USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+USE_JAX = os.environ.get("USE_JAX", "AUTO").upper()
 
 TORCH_VERSION = "N/A"
 TORCH_AVAILABLE = False
@@ -87,6 +88,21 @@
     logger.info("Disabling Tensorflow because USE_TORCH is set")
 
+JAX_VERSION = "N/A"
+JAX_AVAILABLE = False
+
+if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
+    JAX_AVAILABLE = importlib.util.find_spec("jax") is not None
+    if JAX_AVAILABLE:
+        try:
+            JAX_VERSION = importlib_metadata.version("jax")
+            logger.info(f"JAX version {JAX_VERSION} available.")
+        except importlib_metadata.PackageNotFoundError:
+            pass
+else:
+    logger.info("Disabling JAX because USE_JAX is set to False")
+
+
 USE_BEAM = os.environ.get("USE_BEAM", "AUTO").upper()
 BEAM_VERSION = "N/A"
 BEAM_AVAILABLE = False
diff --git a/src/datasets/features.py b/src/datasets/features.py
index bb18cd80d96..b69a936fa3d 100644
--- a/src/datasets/features.py
+++ b/src/datasets/features.py
@@ -17,6 +17,7 @@
 """ This class handle features definition in datasets and some utilities to display table type."""
 import copy
 import re
+import sys
 from collections.abc import Iterable
 from dataclasses import dataclass, field, fields
 from typing import Any, ClassVar, Dict, List, Optional
@@ -158,18 +159,23 @@ def _cast_to_python_objects(obj: Any) -> Tuple[Any, bool]:
         has_changed (bool): True if the object has been changed, False if it is identical
     """
-    if config.TF_AVAILABLE:
+    if config.TF_AVAILABLE and "tensorflow" in sys.modules:
         import tensorflow as tf
-    if config.TORCH_AVAILABLE:
+    if config.TORCH_AVAILABLE and "torch" in sys.modules:
         import torch
+    if config.JAX_AVAILABLE and "jax" in sys.modules:
+        import jax.numpy as jnp
+
     if isinstance(obj, np.ndarray):
         return obj.tolist(), True
-    elif config.TORCH_AVAILABLE and isinstance(obj, torch.Tensor):
+    elif config.TORCH_AVAILABLE and "torch" in sys.modules and isinstance(obj, torch.Tensor):
         return obj.detach().cpu().numpy().tolist(), True
-    elif config.TF_AVAILABLE and isinstance(obj, tf.Tensor):
+    elif config.TF_AVAILABLE and "tensorflow" in sys.modules and isinstance(obj, tf.Tensor):
         return obj.numpy().tolist(), True
+    elif config.JAX_AVAILABLE and "jax" in sys.modules and isinstance(obj, jnp.ndarray):
+        return obj.tolist(), True
     elif isinstance(obj, pd.Series):
         return obj.values.tolist(), True
     elif isinstance(obj, pd.DataFrame):
diff --git a/src/datasets/formatting/__init__.py b/src/datasets/formatting/__init__.py
index d272a7fe888..325c4bd4866 100644
--- a/src/datasets/formatting/__init__.py
+++ b/src/datasets/formatting/__init__.py
@@ -93,6 +93,14 @@ def _register_unavailable_formatter(
     _tf_error = ValueError("Tensorflow needs to be installed to be able to return Tensorflow tensors.")
     _register_unavailable_formatter(_tf_error, "tensorflow", aliases=["tf"])
 
+if config.JAX_AVAILABLE:
+    from .jax_formatter import JaxFormatter
+
+    _register_formatter(JaxFormatter, "jax", aliases=[])
+else:
+    _jax_error = ValueError("JAX needs to be installed to be able to return JAX arrays.")
+    _register_unavailable_formatter(_jax_error, "jax", aliases=[])
+
 
 def get_format_type_from_alias(format_type: Optional[str]) -> Optional[str]:
     """If the given format type is a known alias, then return its main type name.
    Otherwise return the type with no change."""
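The ``"jax" in sys.modules`` guard introduced in ``features.py`` above is worth spelling out; a standalone sketch of the pattern (the helper name ``_maybe_jax_array_to_list`` is hypothetical, not part of this patch) that only pays jax's import cost when the caller has already imported it:

.. code-block:: python

    import sys

    from datasets import config


    def _maybe_jax_array_to_list(obj):
        # Mirrors the guard in _cast_to_python_objects: only look at jax when it
        # is installed *and* the caller has already imported it, so that datasets
        # never imports jax on behalf of users who don't need it.
        if config.JAX_AVAILABLE and "jax" in sys.modules:
            import jax.numpy as jnp

            if isinstance(obj, jnp.ndarray):
                return obj.tolist(), True
        return obj, False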
diff --git a/src/datasets/formatting/jax_formatter.py b/src/datasets/formatting/jax_formatter.py
new file mode 100644
index 00000000000..ad1421040f6
--- /dev/null
+++ b/src/datasets/formatting/jax_formatter.py
@@ -0,0 +1,75 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pyarrow as pa
+
+from ..utils.py_utils import map_nested
+from .formatting import Formatter
+
+
+if TYPE_CHECKING:
+    import jax.numpy as jnp
+
+
+class JaxFormatter(Formatter[dict, "jnp.ndarray", dict]):
+    def __init__(self, **jnp_array_kwargs):
+        self.jnp_array_kwargs = jnp_array_kwargs
+        import jax.numpy as jnp  # noqa: F401 - import jax at initialization to fail early if it is missing
+
+    def _tensorize(self, value):
+        import jax
+        import jax.numpy as jnp
+
+        default_dtype = {}
+        if np.issubdtype(value.dtype, np.integer):
+            # the default int precision depends on the jax config
+            # see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
+            if jax.config.jax_enable_x64:
+                default_dtype = {"dtype": jnp.int64}
+            else:
+                default_dtype = {"dtype": jnp.int32}
+        elif np.issubdtype(value.dtype, np.floating):
+            default_dtype = {"dtype": jnp.float32}
+
+        # calling jnp.array on a np.ndarray does copy the data
+        # see https://github.com/google/jax/issues/4486
+        return jnp.array(value, **{**default_dtype, **self.jnp_array_kwargs})
+
+    def _recursive_tensorize(self, data_struct: dict):
+        # support for nested types like struct of list of struct
+        if isinstance(data_struct, (list, np.ndarray)):
+            data_struct = np.array(data_struct, copy=False)
+            if data_struct.dtype == object:  # jax arrays cannot be instantiated from an array of objects
+                return [self.recursive_tensorize(substruct) for substruct in data_struct]
+        return self._tensorize(data_struct)
+
+    def recursive_tensorize(self, data_struct: dict):
+        return map_nested(self._recursive_tensorize, data_struct, map_list=False)
+
+    def format_row(self, pa_table: pa.Table) -> dict:
+        row = self.numpy_arrow_extractor().extract_row(pa_table)
+        return self.recursive_tensorize(row)
+
+    def format_column(self, pa_table: pa.Table) -> "jnp.ndarray":
+        col = self.numpy_arrow_extractor().extract_column(pa_table)
+        return self.recursive_tensorize(col)
+
+    def format_batch(self, pa_table: pa.Table) -> dict:
+        batch = self.numpy_arrow_extractor().extract_batch(pa_table)
+        return self.recursive_tensorize(batch)
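To make the new formatter's surface concrete, a small usage sketch against an in-memory Arrow table; the column names mirror the dummy tables used in the tests below, and it assumes jax is installed:

.. code-block:: python

    import pyarrow as pa

    from datasets.formatting import JaxFormatter

    pa_table = pa.Table.from_pydict({"a": [0, 1, 2], "c": [0.0, 1.0, 2.0]})

    formatter = JaxFormatter()
    row = formatter.format_row(pa_table)      # dict of jnp scalars for the first row
    batch = formatter.format_batch(pa_table)  # dict of jnp.ndarray columns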
diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py
index 84a38606efb..9e9f8dea8fc 100644
--- a/src/datasets/utils/file_utils.py
+++ b/src/datasets/utils/file_utils.py
@@ -375,6 +375,8 @@ def get_datasets_user_agent(user_agent: Optional[Union[str, dict]] = None) -> st
         ua += "; torch/{}".format(config.TORCH_VERSION)
     if config.TF_AVAILABLE:
         ua += "; tensorflow/{}".format(config.TF_VERSION)
+    if config.JAX_AVAILABLE:
+        ua += "; jax/{}".format(config.JAX_VERSION)
     if config.BEAM_AVAILABLE:
         ua += "; apache_beam/{}".format(config.BEAM_VERSION)
     if isinstance(user_agent, dict):
diff --git a/tests/test_features.py b/tests/test_features.py
index 9bb37b9b695..3c7730e35db 100644
--- a/tests/test_features.py
+++ b/tests/test_features.py
@@ -20,7 +20,7 @@
 )
 from datasets.info import DatasetInfo
 
-from .utils import require_tf, require_torch
+from .utils import require_jax, require_tf, require_torch
 
 
 class FeaturesTest(TestCase):
@@ -294,6 +294,18 @@ def test_cast_to_python_objects_tf(self):
         casted_obj = cast_to_python_objects(obj)
         self.assertDictEqual(casted_obj, expected_obj)
 
+    @require_jax
+    def test_cast_to_python_objects_jax(self):
+        import jax.numpy as jnp
+
+        obj = {
+            "col_1": [{"vec": jnp.array(np.arange(1, 4)), "txt": "foo"}] * 3,
+            "col_2": jnp.array(np.arange(1, 7).reshape(3, 2)),
+        }
+        expected_obj = {"col_1": [{"vec": [1, 2, 3], "txt": "foo"}] * 3, "col_2": [[1, 2], [3, 4], [5, 6]]}
+        casted_obj = cast_to_python_objects(obj)
+        self.assertDictEqual(casted_obj, expected_obj)
+
     @patch("datasets.features._cast_to_python_objects", side_effect=_cast_to_python_objects)
     def test_dont_iterate_over_each_element_in_a_list(self, mocked_cast):
         obj = {"col_1": [[1, 2], [3, 4], [5, 6]]}
diff --git a/tests/test_formatting.py b/tests/test_formatting.py
index fc878d91a52..cca0cb127bc 100644
--- a/tests/test_formatting.py
+++ b/tests/test_formatting.py
@@ -9,7 +9,7 @@
 from datasets.formatting.formatting import NumpyArrowExtractor, PandasArrowExtractor, PythonArrowExtractor
 from datasets.table import InMemoryTable
 
-from .utils import require_tf, require_torch
+from .utils import require_jax, require_tf, require_torch
 
 
 _COL_A = [0, 1, 2]
@@ -195,6 +195,39 @@ def test_tf_formatter_np_array_kwargs(self):
         self.assertEqual(batch["a"].dtype, tf.float16)
         self.assertEqual(batch["c"].dtype, tf.float16)
 
+    @require_jax
+    def test_jax_formatter(self):
+        import jax.numpy as jnp
+
+        from datasets.formatting import JaxFormatter
+
+        pa_table = self._create_dummy_table().drop(["b"])
+        formatter = JaxFormatter()
+        row = formatter.format_row(pa_table)
+        self.assertTrue(jnp.allclose(row["a"], jnp.array(_COL_A, dtype=jnp.int64)[0]))
+        self.assertTrue(jnp.allclose(row["c"], jnp.array(_COL_C, dtype=jnp.float32)[0]))
+        col = formatter.format_column(pa_table)
+        self.assertTrue(jnp.allclose(col, jnp.array(_COL_A, dtype=jnp.int64)))
+        batch = formatter.format_batch(pa_table)
+        self.assertTrue(jnp.allclose(batch["a"], jnp.array(_COL_A, dtype=jnp.int64)))
+        self.assertTrue(jnp.allclose(batch["c"], jnp.array(_COL_C, dtype=jnp.float32)))
+
+    @require_jax
+    def test_jax_formatter_np_array_kwargs(self):
+        import jax.numpy as jnp
+
+        from datasets.formatting import JaxFormatter
+
+        pa_table = self._create_dummy_table().drop(["b"])
+        formatter = JaxFormatter(dtype=jnp.float16)
+        row = formatter.format_row(pa_table)
+        self.assertEqual(row["c"].dtype, jnp.float16)
+        col = formatter.format_column(pa_table)
+        self.assertEqual(col.dtype, jnp.float16)
+        batch = formatter.format_batch(pa_table)
+        self.assertEqual(batch["a"].dtype, jnp.float16)
+        self.assertEqual(batch["c"].dtype, jnp.float16)
+
 
 class QueryTest(TestCase):
     def _create_dummy_table(self):
+ + """ + if not config.JAX_AVAILABLE: + test_case = unittest.skip("test requires JAX")(test_case) + return test_case + + def require_transformers(test_case): """ Decorator marking a test that requires transformers.