|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | # Lint as: python3 |
| 16 | +import sys |
16 | 17 | from typing import TYPE_CHECKING |
17 | 18 |
|
18 | 19 | import numpy as np |
19 | 20 | import pyarrow as pa |
20 | 21 |
|
| 22 | +from .. import config |
21 | 23 | from ..utils.py_utils import map_nested |
22 | 24 | from .formatting import Formatter |
23 | 25 |
|
|
28 | 30 |
|
class JaxFormatter(Formatter[dict, "jnp.ndarray", dict]):
    """Formatter that converts extracted Arrow rows, columns and batches to JAX arrays.

    Data is first extracted through ``numpy_arrow_extractor``, optionally decoded
    with ``python_features_decoder`` (when ``self.decoded`` is truthy), and then
    walked recursively so that every leaf goes through ``_tensorize``.
    """

    def __init__(self, features=None, decoded=True, **jnp_array_kwargs):
        """Store the keyword arguments forwarded to every ``jnp.array`` call.

        Args:
            features: Forwarded unchanged to the parent ``Formatter``.
            decoded: Forwarded unchanged to the parent ``Formatter``.
            **jnp_array_kwargs: Extra keyword arguments passed to ``jnp.array``;
                they take precedence over the default dtype picked in ``_tensorize``.
        """
        super().__init__(features=features, decoded=decoded)
        self.jnp_array_kwargs = jnp_array_kwargs
        # Imported only for its side effect: a missing jax fails here, at construction time.
        import jax.numpy as jnp  # noqa import jax at initialization

    def _consolidate(self, column):
        """Stack a non-empty list of same-shape, same-dtype JAX arrays into one array.

        Anything else (a non-list, an empty list, or a list whose elements are not
        all JAX arrays with the shape and dtype of the first element) is returned
        unchanged.
        """
        import jax.numpy as jnp

        if not isinstance(column, list) or not column:
            return column

        reference = column[0]

        def matches_reference(candidate):
            # The isinstance check comes first so .shape/.dtype are only read on JAX arrays.
            return (
                isinstance(candidate, jnp.ndarray)
                and candidate.shape == reference.shape
                and candidate.dtype == reference.dtype
            )

        if all(map(matches_reference, column)):
            return jnp.stack(column)
        return column

    def _tensorize(self, value):
        """Convert a single leaf value to a JAX array where that is possible.

        ``str``, ``bytes`` and ``None`` are returned untouched, and numpy values
        with a character dtype are returned as plain Python objects via
        ``tolist()``. Everything else is handed to ``jnp.array``.
        """
        import jax
        import jax.numpy as jnp

        if value is None or isinstance(value, (str, bytes)):
            return value
        if isinstance(value, (np.character, np.ndarray)) and np.issubdtype(value.dtype, np.character):
            return value.tolist()

        dtype_override = {}
        is_numeric = isinstance(value, (np.number, np.ndarray))

        if is_numeric and np.issubdtype(value.dtype, np.integer):
            # the default int precision depends on the jax config
            # see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
            dtype_override["dtype"] = jnp.int64 if jax.config.jax_enable_x64 else jnp.int32
        elif is_numeric and np.issubdtype(value.dtype, np.floating):
            dtype_override["dtype"] = jnp.float32
        elif config.PIL_AVAILABLE and "PIL" in sys.modules:
            # PIL is only consulted when it has already been imported elsewhere.
            import PIL.Image

            if isinstance(value, PIL.Image.Image):
                value = np.asarray(value)

        # User-supplied kwargs are merged last so they override the default dtype.
        array_kwargs = {**dtype_override, **self.jnp_array_kwargs}

        # calling jnp.array on a np.ndarray does copy the data
        # see https://github.com/google/jax/issues/4486
        return jnp.array(value, **array_kwargs)

    def _recursive_tensorize(self, data_struct: dict):
        """Tensorize one node, descending into object-dtype numpy arrays.

        An object-dtype array cannot be turned into a JAX array directly, so each
        of its elements is tensorized on its own and the results are consolidated.
        """
        # support for nested types like struct of list of struct
        is_object_array = isinstance(data_struct, np.ndarray) and data_struct.dtype == object
        if not is_object_array:
            return self._tensorize(data_struct)

        tensorized_items = [self.recursive_tensorize(item) for item in data_struct]
        return self._consolidate(tensorized_items)

    def recursive_tensorize(self, data_struct: dict):
        """Apply ``_recursive_tensorize`` across ``data_struct`` through ``map_nested``."""
        return map_nested(self._recursive_tensorize, data_struct)

    def format_row(self, pa_table: pa.Table) -> dict:
        """Extract one row from ``pa_table``, decode it if requested, and tensorize it."""
        extractor = self.numpy_arrow_extractor()
        row = extractor.extract_row(pa_table)
        if self.decoded:
            row = self.python_features_decoder.decode_row(row)
        return self.recursive_tensorize(row)

    def format_column(self, pa_table: pa.Table) -> "jnp.ndarray":
        """Extract one column from ``pa_table``, decode it if requested, and tensorize it.

        The tensorized column is passed through ``_consolidate`` before being returned.
        """
        extractor = self.numpy_arrow_extractor()
        values = extractor.extract_column(pa_table)
        if self.decoded:
            # The decoder is given the name of the table's first column.
            column_name = pa_table.column_names[0]
            values = self.python_features_decoder.decode_column(values, column_name)
        tensorized = self.recursive_tensorize(values)
        return self._consolidate(tensorized)

    def format_batch(self, pa_table: pa.Table) -> dict:
        """Extract a batch from ``pa_table``, decode it if requested, and tensorize it.

        Each column of the tensorized batch is passed through ``_consolidate``.
        """
        extractor = self.numpy_arrow_extractor()
        batch = extractor.extract_batch(pa_table)
        if self.decoded:
            batch = self.python_features_decoder.decode_batch(batch)
        batch = self.recursive_tensorize(batch)
        # Values are replaced in place; the key set is never changed during iteration.
        for column_name, tensors in batch.items():
            batch[column_name] = self._consolidate(tensors)
        return batch
0 commit comments