Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Hi !
I just added the "jax" formatting, as we already have for pytorch, tensorflow, numpy (and also pandas and arrow).
It does pretty much the same thing as the pytorch formatter except it creates jax.numpy.ndarray objects.
A few details:
jax_enable_x64(see here), I took that into account. Unlessjax_enable_x64is specified, it is int32 by defaultUSE_JAX. However I noticed that intransformersit isUSE_FLAX. This is not an issue though IMOI also updated
convert_to_python_objectsto allow users to pass jax.numpy.ndarray objects to build a dataset.Since the
convert_to_python_objectsmethod became slow because it's the time when pytorch, tf (and now jax) are imported, I fixed it by checking thesys.modulesto avoid unecessary import of pytorch, tf or jax.Close #2495