Skip to content

visualize_array_sharding/inspect_array_sharding doesn't work with shard_map #23936

@Findus23

Description

@Findus23

Description

I really like using inspect_array_sharding, and my own sharding_info() helper built on top of it, to better understand how sharding works in JAX and what is going wrong in my code.

Recently, however, I have been using shard_map more, and the inspect_array_sharding callback seems to be broken inside it.

# Reproduction setup: force four host (CPU) devices, then build a 1-D device
# mesh with a single axis named 'a' and a NamedSharding partitioned along it.
import os

from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec
from jax.experimental import mesh_utils

# Ask XLA to expose 4 host-platform devices.
# NOTE(review): jax submodules are already imported above this assignment; the
# 4-device outputs shown further down in this report indicate the flag still
# takes effect, presumably because devices are initialized lazily -- confirm.
# NOTE(review): the f-string has no placeholders, so the `f` prefix is redundant.
os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count=4'

import jax
import jax.numpy as jnp
from jax.debug import visualize_array_sharding

# Device mesh of shape (4,).
devices = mesh_utils.create_device_mesh((4,))

# Mesh over those devices with one named axis, 'a'.
mesh = Mesh(devices, axis_names=('a',))
# Sharding passed as `out_shardings` to the jitted variant below.
sharding = jax.sharding.NamedSharding(mesh, PartitionSpec('a'))

def some_function():
    """Create a zero array of length 1000, visualize its sharding, and return it."""
    a = jnp.zeros(1000)
    # This is the call that raises when the function runs under shard_map
    # (see the tracebacks in this report).
    visualize_array_sharding(a)
    return a

# 1) Plain (eager) call: the visualization shows a single device.
some_function()

# ┌───────┐
# │ CPU 0 │
# └───────┘

# 2) Jitted with `out_shardings=sharding`: the visualization shows all four devices.
some_function_jitted = jax.jit(some_function, out_shardings=sharding)

some_function_jitted()

# ┌───────┬───────┬───────┬───────┐
# │ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │
# └───────┴───────┴───────┴───────┘

# 3) Wrapped in shard_map over the same mesh.
# NOTE(review): `in_specs` is given although `some_function` takes no arguments.
some_function_shard_map = shard_map(
    some_function,
    mesh=mesh,
    in_specs=PartitionSpec(None),
    out_specs=PartitionSpec("a"),
    # Uncommenting the next line replaces the NotImplementedError with the
    # XlaRuntimeError shown in the second traceback of this report.
    # check_rep=False
)

# This call fails; no visualization is printed.
some_function_shard_map()

This raises the following error:

Traceback (most recent call last):
  File ".../inspect_shard_map.py", line 48, in <module>
    some_function_shard_map()
  File ".../inspect_shard_map.py", line 22, in some_function
    visualize_array_sharding(a)
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 627, in visualize_array_sharding
    inspect_array_sharding(arr, callback=_visualize)
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 621, in inspect_array_sharding
    tree_util.tree_map(_inspect, value)
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 620, in _inspect
    inspect_sharding_p.bind(val, callback=callback)
NotImplementedError: No replication rule for inspect_sharding. As a workaround, pass the `check_rep=False` argument to `shard_map`. To get this fixed, open an issue at https://github.com/google/jax/issues

But even with check_rep=False added, the code still fails, this time with a different error:

Traceback (most recent call last):
  File ".../inspect_shard_map.py", line 48, in <module>
    some_function_shard_map()
  File ".../inspect_shard_map.py", line 22, in some_function
    visualize_array_sharding(a)
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 627, in visualize_array_sharding
    inspect_array_sharding(arr, callback=_visualize)
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 621, in inspect_array_sharding
    tree_util.tree_map(_inspect, value)
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 620, in _inspect
    inspect_sharding_p.bind(val, callback=callback)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error calling inspect_sharding: Traceback (most recent call last):
  File ".../inspect_shard_map.py", line 48, in <module>
  File "venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "venv/lib/python3.12/site-packages/jax/experimental/shard_map.py", line 191, in wrapped
  File "venv/lib/python3.12/site-packages/jax/experimental/shard_map.py", line 473, in bind
  File "venv/lib/python3.12/site-packages/jax/experimental/shard_map.py", line 745, in _shard_map_impl
  File "venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 193, in call_wrapped
  File ".../inspect_shard_map.py", line 22, in some_function
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 627, in visualize_array_sharding
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 621, in inspect_array_sharding
  File "venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 344, in tree_map
  File "venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 344, in <genexpr>
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 620, in _inspect
  File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 439, in bind
  File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 443, in bind_with_trace
  File "venv/lib/python3.12/site-packages/jax/experimental/shard_map.py", line 835, in process_primitive
  File "venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 332, in cache_miss
  File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 190, in _python_pjit_helper
  File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 2782, in bind
  File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 443, in bind_with_trace
  File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 949, in process_primitive
  File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1739, in _pjit_call_impl
  File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1721, in call_impl_cache_miss
  File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1651, in _pjit_call_impl_python
  File "venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
  File "venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2827, in from_hlo
  File "venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2639, in _cached_compilation
  File "venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 426, in compile_or_get_cached
  File "venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 654, in _compile_and_write_cache
  File "venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
  File "venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 266, in backend_compile
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 385, in _hlo_sharding_callback
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 626, in _visualize
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 496, in visualize_sharding
  File "venv/lib/python3.12/site-packages/jax/_src/sharding.py", line 200, in devices_indices_map
  File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 302, in wrapper
  File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 296, in cached
  File "venv/lib/python3.12/site-packages/jax/_src/sharding.py", line 49, in common_devices_indices_map
  File "venv/lib/python3.12/site-packages/jax/_src/sharding.py", line 217, in shard_shape
  File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 302, in wrapper
  File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 296, in cached
  File "venv/lib/python3.12/site-packages/jax/_src/sharding.py", line 58, in _common_shard_shape
  File "venv/lib/python3.12/site-packages/jax/_src/sharding_impls.py", line 745, in _to_xla_hlo_sharding
  File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 302, in wrapper
  File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 296, in cached
  File "venv/lib/python3.12/site-packages/jax/_src/sharding_impls.py", line 617, in _positional_sharding_to_xla_hlo_sharding
ValueError: not enough values to unpack (expected 1, got 0)

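The line that raises the ValueError (the last frame above, sharding_impls.py line 617) appears to be:

set_size, = {len(device_set) for device_set in self._ids.flat}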

It would be great if inspect_array_sharding could work inside a shard_map the same way it already does inside a jitted function.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.33
jaxlib: 0.4.33
numpy:  2.0.1
python: 3.12.6 (main, Sep  7 2024, 14:20:15) [GCC 14.2.0]
jax.devices (4 total, 4 local): [CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)]
process_count: 1
platform: uname_result(system='Linux', node='standpc', release='6.10.9-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.10.9-1 (2024-09-08)', machine='x86_64')

Metadata

Metadata

Assignees

Labels

bug: Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions