JAX does not support string data types, so there's no way to do this directly. One pattern you could use is a data structure that holds a static tuple of strings alongside a dynamic array of integer indices into it. It can be passed through JIT and other transformations and manipulated there, and the NumPy string array can be recovered afterward. Here's a small example that supports slicing:

import dataclasses
import numpy as np
import jax
import jax.numpy as jnp


@jax.tree_util.register_dataclass
@dataclasses.dataclass
class StringArray:
  labels: tuple[str, ...] = dataclasses.field(metadata=dict(static=True))
  data: jax.Array

  @classmethod
  def from_numpy(cls, data: np.ndarray | list[
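The snippet above is cut off partway through the `from_numpy` signature. As a rough, self-contained sketch of how the same pattern might look end to end (not necessarily the original answer's code): the body of `from_numpy`, the `to_numpy` and `__getitem__` methods, and the `reverse` function are all assumptions made for illustration. The fields are also registered explicitly via `data_fields`/`meta_fields` instead of the `static=True` metadata, on the assumption that this form is accepted by a wider range of JAX versions.

```python
import dataclasses
import functools

import jax
import jax.numpy as jnp
import numpy as np


# Register the dataclass as a pytree: `data` is a traced leaf,
# `labels` is static metadata (it must be hashable, hence a tuple).
@functools.partial(
    jax.tree_util.register_dataclass,
    data_fields=["data"],
    meta_fields=["labels"],
)
@dataclasses.dataclass
class StringArray:
  labels: tuple    # static: the distinct strings, sorted
  data: jax.Array  # dynamic: integer indices into `labels`

  @classmethod
  def from_numpy(cls, strings):
    # Assumed implementation: encode each string as its index
    # into the sorted tuple of unique values.
    arr = np.asarray(strings)
    labels, indices = np.unique(arr, return_inverse=True)
    indices = indices.reshape(arr.shape).astype(np.int32)
    return cls(labels=tuple(labels.tolist()), data=jnp.asarray(indices))

  def to_numpy(self):
    # Decode the indices back into a NumPy string array (outside of JIT).
    return np.asarray(self.labels)[np.asarray(self.data)]

  def __getitem__(self, idx):
    # Slicing only touches the index array; the labels are unchanged.
    return StringArray(labels=self.labels, data=self.data[idx])


@jax.jit
def reverse(x):
  # Inside JIT, `x.data` is a tracer while `x.labels` stays a Python tuple.
  return x[::-1]


s = StringArray.from_numpy(["b", "a", "c", "a"])
print(reverse(s).to_numpy())  # ['a' 'c' 'a' 'b']
```

Since `labels` is static, a `StringArray` with a different set of labels triggers a recompile of any jitted function it is passed to, so this fits best when the vocabulary of strings is small and stable.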

Answer selected by fishjojo