Given this MWE:

```python
import jax
import jax.numpy as jnp

def get_expensive_scalar(arg1, arg2):
    """Expensive computation that eventually returns a scalar"""
    return arg1 + arg2

def _add_scalar(arr_elem, arg1, arg2):
    """The function we will vmap"""
    scalar = get_expensive_scalar(arg1, arg2)  # does this run for each element of `vec_in`?
    result = arr_elem + scalar
    return result

add_scalar = jax.jit(jax.vmap(_add_scalar, in_axes=[0, None, None]))

vec_in = jnp.ones(100)
add_scalar(vec_in, 1, 2)
```

Does JAX understand that `get_expensive_scalar(arg1, arg2)` only needs to be computed once, rather than once for each element of `vec_in`?
The most definitive way to answer this is to use ahead-of-time (AOT) compilation to print the compiled HLO for your function. For example:

```python
print(add_scalar.lower(vec_in, 1, 2).compile().as_text())
```

This is admittedly a bit tricky to read, but you can see in the relevant line of the output that the scalars are added only once before the result is broadcast to the larger shape and added to the batched input. There is one detail here: this part of the computation takes place within a fusion, meaning the compiler may rearrange operations if it deems that advantageous; that said, if your operation really is something expensive (where …).

I hope that helps!
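As a complementary check (a sketch added here, not part of the original reply), you can also inspect the jaxpr that `jax.vmap` produces before XLA ever sees it. Because `arg1` and `arg2` are mapped with `in_axes=None`, the batching transformation leaves the scalar computation unbatched:

```python
import jax
import jax.numpy as jnp

# Reuses _add_scalar and vec_in from the MWE above.
jaxpr = jax.make_jaxpr(jax.vmap(_add_scalar, in_axes=[0, None, None]))(vec_in, 1, 2)
print(jaxpr)
# The printed jaxpr contains a single scalar add of the two unbatched
# arguments, followed by one broadcasting add against the batched input;
# i.e. the scalar is computed once, not once per element of vec_in.
```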
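If the scalar computation is expensive enough that you would rather not rely on the compiler at all, one possible restructuring (my sketch, not from the original reply; the `*_hoisted` names are made up for illustration) is to hoist it out of the vmapped function, so that it runs once by construction:

```python
import jax

def _add_scalar_hoisted(arr_elem, scalar):
    """Like _add_scalar, but takes the precomputed scalar as an argument."""
    return arr_elem + scalar

@jax.jit
def add_scalar_hoisted(vec_in, arg1, arg2):
    # get_expensive_scalar (from the MWE above) is traced exactly once here,
    # outside the vmapped function.
    scalar = get_expensive_scalar(arg1, arg2)
    return jax.vmap(_add_scalar_hoisted, in_axes=[0, None])(vec_in, scalar)
```

Either way the compiled HLO should be equivalent; hoisting just makes the "once" explicit in the program rather than a property of the batching rule.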