@lhoestq lhoestq commented Jun 14, 2021

Hi!

I just added the "jax" formatting, alongside the existing pytorch, tensorflow and numpy formats (and also pandas and arrow).
It does pretty much the same thing as the pytorch formatter, except that it creates jax.numpy.ndarray objects.

from datasets import Dataset

d = Dataset.from_dict({"foo": [[0., 1., 2.]]})
d = d.with_format("jax")
d[0]
# {'foo': DeviceArray([0., 1., 2.], dtype=float32)}

A few details:

  • The default integer precision for jax depends on the jax configuration flag jax_enable_x64 (see here), and I took that into account. Unless jax_enable_x64 is enabled, integers are int32 by default.
  • AFAIK it's not possible to convert arrow data directly to jax data. We do arrow -> numpy -> jax, but unfortunately the numpy -> jax step is not zero-copy (see here).
  • The env var for disabling JAX is USE_JAX. However, I noticed that in transformers it is USE_FLAX. This is not an issue though IMO.
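
To make the first point concrete, here is a minimal sketch of how the integer width could follow jax_enable_x64. The helper names `jax_x64_enabled` and `default_int_dtype_name` are hypothetical and are not the code in this PR; the sketch also assumes the flag is readable as `jax.config.jax_enable_x64`, and it treats the flag as off when jax has not been imported at all.

```python
import sys


def jax_x64_enabled() -> bool:
    """Best-effort read of jax_enable_x64 without importing jax (hypothetical helper)."""
    jax = sys.modules.get("jax")
    if jax is None:
        # Assumption of this sketch: if jax was never imported, treat the flag as off.
        return False
    # Assumption: the flag is exposed as an attribute of jax.config.
    return bool(getattr(jax.config, "jax_enable_x64", False))


def default_int_dtype_name(x64_enabled: bool) -> str:
    """Name of the integer dtype to use: int32 unless 64-bit mode is enabled."""
    return "int64" if x64_enabled else "int32"


print(default_int_dtype_name(jax_x64_enabled()))
```

A formatter could use such a name to pick the dtype it passes when building the jax array, so that integer columns do not trigger a truncation warning when 64-bit mode is off.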

I also updated convert_to_python_objects to allow users to pass jax.numpy.ndarray objects to build a dataset.

convert_to_python_objects had become slow because that is where pytorch, tf (and now jax) get imported. I fixed this by checking sys.modules first, which avoids unnecessary imports of pytorch, tf or jax.
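
The sys.modules idea can be sketched as follows. This is an illustration under assumptions, not the code in this PR: `is_loaded_instance` is a hypothetical helper name, and the point is only that an object cannot be a torch, tf or jax array if that framework was never imported, so the check can skip the import entirely.

```python
import sys


def is_loaded_instance(obj, module_name: str, class_path: str) -> bool:
    """Return True only if `module_name` is already imported and `obj` is an
    instance of the class found at `class_path` inside it (hypothetical helper)."""
    module = sys.modules.get(module_name)
    if module is None:
        # The framework was never imported, so obj cannot be one of its types.
        return False
    cls = module
    for part in class_path.split("."):
        cls = getattr(cls, part, None)
        if cls is None:
            return False
    return isinstance(cls, type) and isinstance(obj, cls)


# No framework import is triggered by any of these checks.
value = [0.0, 1.0, 2.0]
print(is_loaded_instance(value, "torch", "Tensor"))
print(is_loaded_instance(value, "jax", "numpy.ndarray"))
```

With this pattern, a plain Python list goes through the conversion without paying the import cost of any framework the user has not already loaded.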

Close #2495

@lhoestq lhoestq marked this pull request as ready for review June 14, 2021 17:44
@lhoestq lhoestq merged commit 537402c into master Jun 21, 2021
@lhoestq lhoestq deleted the jax-formatting branch June 21, 2021 16:15
