-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Closed
Closed
Copy link
Labels
bug: Something isn't working
Description
Description
I really like using inspect_array_sharding and my own sharding_info() based on it, to better understand how sharding works in jax and what is going wrong in my code.
Recently, however, I have been using shard_map more, and inside a shard_map the inspect_array_sharding callback seems to be broken.
# Minimal reproduction: visualize_array_sharding works eagerly and under
# jax.jit, but fails when the same function is wrapped in shard_map.
import os
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec
from jax.experimental import mesh_utils

# Ask XLA to expose 4 host (CPU) devices so a 4-way mesh can be built.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

import jax
import jax.numpy as jnp
from jax.debug import visualize_array_sharding

# One-dimensional mesh over the 4 devices, with a single axis named 'a'.
devices = mesh_utils.create_device_mesh((4,))
mesh = Mesh(devices, axis_names=('a',))
sharding = jax.sharding.NamedSharding(mesh, PartitionSpec('a'))


def some_function():
    """Create a length-1000 zero array, visualize its sharding, and return it."""
    a = jnp.zeros(1000)
    visualize_array_sharding(a)
    return a


# Case 1: eager call.
some_function()
# ┌───────┐
# │ CPU 0 │
# └───────┘

# Case 2: jitted, with the output sharded along mesh axis 'a'.
some_function_jitted = jax.jit(some_function, out_shardings=sharding)
some_function_jitted()
# ┌───────┬───────┬───────┬───────┐
# │ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │
# └───────┴───────┴───────┴───────┘

# Case 3: the same function wrapped in shard_map (calling it is what fails).
some_function_shard_map = shard_map(
    some_function,
    mesh=mesh,
    in_specs=PartitionSpec(None),
    out_specs=PartitionSpec("a"),
    # check_rep=False
)
some_function_shard_map()

This causes the following issue:
Traceback (most recent call last):
File ".../inspect_shard_map.py", line 48, in <module>
some_function_shard_map()
File ".../inspect_shard_map.py", line 22, in some_function
visualize_array_sharding(a)
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 627, in visualize_array_sharding
inspect_array_sharding(arr, callback=_visualize)
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 621, in inspect_array_sharding
tree_util.tree_map(_inspect, value)
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 620, in _inspect
inspect_sharding_p.bind(val, callback=callback)
NotImplementedError: No replication rule for inspect_sharding. As a workaround, pass the `check_rep=False` argument to `shard_map`. To get this fixed, open an issue at https://github.com/google/jax/issues
But even with check_rep=False added, the code still fails:
Traceback (most recent call last):
File ".../inspect_shard_map.py", line 48, in <module>
some_function_shard_map()
File ".../inspect_shard_map.py", line 22, in some_function
visualize_array_sharding(a)
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 627, in visualize_array_sharding
inspect_array_sharding(arr, callback=_visualize)
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 621, in inspect_array_sharding
tree_util.tree_map(_inspect, value)
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 620, in _inspect
inspect_sharding_p.bind(val, callback=callback)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error calling inspect_sharding: Traceback (most recent call last):
File ".../inspect_shard_map.py", line 48, in <module>
File "venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
File "venv/lib/python3.12/site-packages/jax/experimental/shard_map.py", line 191, in wrapped
File "venv/lib/python3.12/site-packages/jax/experimental/shard_map.py", line 473, in bind
File "venv/lib/python3.12/site-packages/jax/experimental/shard_map.py", line 745, in _shard_map_impl
File "venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 193, in call_wrapped
File ".../inspect_shard_map.py", line 22, in some_function
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 627, in visualize_array_sharding
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 621, in inspect_array_sharding
File "venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 344, in tree_map
File "venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 344, in <genexpr>
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 620, in _inspect
File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 439, in bind
File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 443, in bind_with_trace
File "venv/lib/python3.12/site-packages/jax/experimental/shard_map.py", line 835, in process_primitive
File "venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 332, in cache_miss
File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 190, in _python_pjit_helper
File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 2782, in bind
File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 443, in bind_with_trace
File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 949, in process_primitive
File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1739, in _pjit_call_impl
File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1721, in call_impl_cache_miss
File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1651, in _pjit_call_impl_python
File "venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
File "venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2827, in from_hlo
File "venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2639, in _cached_compilation
File "venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 426, in compile_or_get_cached
File "venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 654, in _compile_and_write_cache
File "venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
File "venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 266, in backend_compile
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 385, in _hlo_sharding_callback
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 626, in _visualize
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 496, in visualize_sharding
File "venv/lib/python3.12/site-packages/jax/_src/sharding.py", line 200, in devices_indices_map
File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 302, in wrapper
File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 296, in cached
File "venv/lib/python3.12/site-packages/jax/_src/sharding.py", line 49, in common_devices_indices_map
File "venv/lib/python3.12/site-packages/jax/_src/sharding.py", line 217, in shard_shape
File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 302, in wrapper
File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 296, in cached
File "venv/lib/python3.12/site-packages/jax/_src/sharding.py", line 58, in _common_shard_shape
File "venv/lib/python3.12/site-packages/jax/_src/sharding_impls.py", line 745, in _to_xla_hlo_sharding
File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 302, in wrapper
File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 296, in cached
File "venv/lib/python3.12/site-packages/jax/_src/sharding_impls.py", line 617, in _positional_sharding_to_xla_hlo_sharding
ValueError: not enough values to unpack (expected 1, got 0)
The error originates at this line (jax/jax/_src/sharding_impls.py, line 617 in 1594d2f):
`set_size, = {len(device_set) for device_set in self._ids.flat}`
It would be great if inspect_array_sharding could work inside a shard_map the same way it already does inside a jitted function.
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.33
jaxlib: 0.4.33
numpy: 2.0.1
python: 3.12.6 (main, Sep 7 2024, 14:20:15) [GCC 14.2.0]
jax.devices (4 total, 4 local): [CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)]
process_count: 1
platform: uname_result(system='Linux', node='standpc', release='6.10.9-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.10.9-1 (2024-09-08)', machine='x86_64')
Metadata
Metadata
Assignees
Labels
bug: Something isn't working