I have a very simple Python script as follows: However, I'm not able to convert it to a JAX-compatible version in which the indices are traced, e.g., won't work. Is there any workaround for this? Thank you in advance.
JAX does not support string data types, so there's no way to do this directly. One pattern you could use instead is a data structure that holds a static tuple of the unique strings alongside a dynamic integer array of indices into it. That object can be passed to `jax.jit` and other transformations and manipulated there, and the NumPy string array can be recovered afterward. Here's a small example that supports indexing and slicing:

```python
import dataclasses

import numpy as np
import jax
import jax.numpy as jnp


@jax.tree_util.register_dataclass
@dataclasses.dataclass
class StringArray:
    # Unique string values. Marked static, so they are part of the pytree
    # structure (hashable metadata) rather than a traced leaf.
    labels: tuple[str, ...] = dataclasses.field(metadata=dict(static=True))
    # Integer codes indexing into `labels`; this leaf is traced under JIT.
    data: jax.Array

    @classmethod
    def from_numpy(cls, data: np.ndarray | list[str]):
        data = np.asarray(data, dtype=str)
        labels, values = np.unique(data, return_inverse=True)
        return cls(
            labels=tuple(map(str, labels)),
            data=jnp.asarray(values).reshape(data.shape),
        )

    def __array__(self, dtype=None, copy=None):
        # Recover the NumPy string array on the host (not inside JIT).
        return np.asarray(self.labels, dtype=dtype)[np.asarray(self.data)]

    def __getitem__(self, item):
        # Indexing only touches the integer codes, so it works on tracers.
        return self.__class__(self.labels, self.data[item])

    def __repr__(self):
        return f"StringArray({np.array(self)!s})"


@jax.jit
def get_symbols(symbols, idx):
    return symbols[idx]


arr = StringArray.from_numpy(['a', 'b', 'c', 'd'])
idx = np.array([1, 3])
print(arr)                    # StringArray(['a' 'b' 'c' 'd'])
print(get_symbols(arr, idx))  # StringArray(['b' 'd'])
```

Depending on what you want to do with the string values within JIT, this approach may or may not work. For example, doing string operations under JIT would be tricky. But if all you need is to index or slice a string array within JIT, it may be sufficient.
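Some operations beyond slicing can still be expressed on the integer codes, as long as the string side stays static. The sketch below is a minimal illustration under that assumption: it skips the `StringArray` class and uses a bare `(labels, codes)` pair, and `equals` and `counts` are hypothetical helper names, not part of JAX. Because `labels` and the target string are passed as static arguments, the string lookup happens in Python at trace time and only integer work is traced.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical standalone encoding (same idea as StringArray, without the
# class): sorted unique labels as a static tuple, integer codes as the array.
strings = np.array(['b', 'a', 'd', 'a'])
uniq, inverse = np.unique(strings, return_inverse=True)
labels = tuple(map(str, uniq))
codes = jnp.asarray(inverse.reshape(strings.shape))


@partial(jax.jit, static_argnames=('labels', 'target'))
def equals(codes, labels, target):
    # `labels` and `target` are static, so `labels.index` runs in Python at
    # trace time; only the integer comparison is traced.
    if target not in labels:
        return jnp.zeros(codes.shape, dtype=bool)
    return codes == labels.index(target)


@partial(jax.jit, static_argnames=('labels',))
def counts(codes, labels):
    # Occurrences of each label; `length` must be static, and it is here
    # because `labels` is a static argument.
    return jnp.bincount(codes, length=len(labels))


mask = equals(codes, labels=labels, target='a')
per_label = counts(codes, labels=labels)
```

Here `mask` marks the positions holding `'a'` and `per_label` counts each label in sorted order. Anything that needs to create new strings inside the traced function (concatenation, formatting) does not fit this pattern, since the label set is fixed at trace time.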