-
Notifications
You must be signed in to change notification settings - Fork 3.1k
JAX integration #2502
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
JAX integration #2502
Changes from 4 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| # coding=utf-8 | ||
| # Copyright 2020 The HuggingFace Authors. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # Lint as: python3 | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| import numpy as np | ||
| import pyarrow as pa | ||
|
|
||
| from ..utils.py_utils import map_nested | ||
| from .formatting import Formatter | ||
|
|
||
|
|
||
| if TYPE_CHECKING: | ||
| import jax.numpy as jnp | ||
|
|
||
|
|
||
| class JaxFormatter(Formatter[dict, "jnp.ndarray", dict]): | ||
| def __init__(self, **jnp_array_kwargs): | ||
| self.jnp_array_kwargs = jnp_array_kwargs | ||
| import jax.numpy as jnp # noqa import jax at initialization | ||
|
|
||
| def _tensorize(self, value): | ||
| import jax | ||
| import jax.numpy as jnp | ||
|
|
||
| default_dtype = {} | ||
| if np.issubdtype(value.dtype, np.integer): | ||
| # the default int precision depends on the jax config | ||
| # see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision | ||
| if jax.config.jax_enable_x64: | ||
| default_dtype = {"dtype": jnp.int64} | ||
| else: | ||
| default_dtype = {"dtype": jnp.int32} | ||
| elif np.issubdtype(value.dtype, np.floating): | ||
| default_dtype = {"dtype": jnp.float32} | ||
|
|
||
| # calling jnp.array on a np.ndarray does copy the data | ||
| # see https://github.com/google/jax/issues/4486 | ||
| return jnp.array(value, **{**default_dtype, **self.jnp_array_kwargs}) | ||
|
|
||
| def _recursive_tensorize(self, data_struct: dict): | ||
| # support for nested types like struct of list of struct | ||
| if isinstance(data_struct, (list, np.ndarray)): | ||
| data_struct = np.array(data_struct, copy=False) | ||
| if data_struct.dtype == np.object: # jax arrays cannot be instantied from an array of objects | ||
| return [self.recursive_tensorize(substruct) for substruct in data_struct] | ||
| return self._tensorize(data_struct) | ||
|
|
||
| def recursive_tensorize(self, data_struct: dict): | ||
| return map_nested(self._recursive_tensorize, data_struct, map_list=False) | ||
|
|
||
| def format_row(self, pa_table: pa.Table) -> dict: | ||
| row = self.numpy_arrow_extractor().extract_row(pa_table) | ||
| return self.recursive_tensorize(row) | ||
|
|
||
| def format_column(self, pa_table: pa.Table) -> "jnp.ndarray": | ||
| col = self.numpy_arrow_extractor().extract_column(pa_table) | ||
| return self.recursive_tensorize(col) | ||
|
|
||
| def format_batch(self, pa_table: pa.Table) -> dict: | ||
| batch = self.numpy_arrow_extractor().extract_batch(pa_table) | ||
| return self.recursive_tensorize(batch) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.