Skip to content

Commit 8855849

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Make inspect_array_sharding inside shard_map work with check_vma=True | check_rep=True. Fixes #23936
PiperOrigin-RevId: 750987053
1 parent 3bc8436 commit 8855849

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

jax/_src/core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1894,7 +1894,10 @@ def str_short_aval(shape, dtype, mesh, spec, vma,
18941894
def get_vma(vma, mesh):
18951895
if mesh.empty:
18961896
return vma
1897+
axis_env_names = get_axis_env().axis_names()
18971898
for i in vma:
1899+
if i in axis_env_names and i not in mesh._name_to_type:
1900+
continue
18981901
if mesh._name_to_type[i] != AxisType.Manual:
18991902
raise ValueError(
19001903
"Axes mentioned in `vma` field of ShapedArray should"

jax/_src/debugging.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
459459
mesh = mesh_lib.Mesh(np.array(devices).reshape(am.axis_sizes),
460460
am.axis_names)
461461
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
462+
mesh = axis_context.mesh
462463
devices = axis_context.mesh._flat_devices_tuple
463464
else:
464465
raise NotImplementedError(type(axis_context))
@@ -470,7 +471,8 @@ def _hlo_sharding_callback(hlo_sharding: xc.HloSharding):
470471
if mesh.empty:
471472
return callback(
472473
sharding_impls._op_sharding_to_pos_sharding(hlo_sharding, devices))
473-
pspec = parse_flatten_op_sharding(hlo_sharding, mesh)[0]
474+
pspec = (P() if hlo_sharding.is_manual() else
475+
parse_flatten_op_sharding(hlo_sharding, mesh)[0])
474476
return callback(NamedSharding(mesh, pspec))
475477

476478
if len(devices) == 1:

tests/debugging_primitives_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from jax._src import debugging
2626
from jax._src import dispatch
2727
from jax._src import test_util as jtu
28+
from jax.sharding import PartitionSpec as P
2829
import jax.numpy as jnp
2930
import numpy as np
3031

@@ -1120,6 +1121,28 @@ def test_visualize_pmap_sharding(self):
11201121
""")
11211122
self.assertEqual(output(), expected)
11221123

1124+
def test_visualize_sharding_shard_map(self):
1125+
mesh = jtu.create_mesh((2,), 'x')
1126+
1127+
def f():
1128+
a = jnp.zeros(1000)
1129+
debugging.visualize_array_sharding(a)
1130+
return a
1131+
1132+
with jtu.capture_stdout() as output:
1133+
f() # doesn't crash
1134+
1135+
with jtu.capture_stdout() as output:
1136+
jax.jit(f, out_shardings=jax.NamedSharding(mesh, P('x')))() # doesn't crash
1137+
1138+
with jtu.capture_stdout() as output:
1139+
jax.shard_map(f, mesh=mesh, in_specs=P(None), out_specs=P("x"))() # doesn't crash
1140+
1141+
with jtu.capture_stdout() as output:
1142+
jax.shard_map(f, mesh=mesh, in_specs=P(None), out_specs=P("x"),
1143+
check_vma=False)() # doesn't crash
1144+
1145+
11231146
class InspectShardingTest(jtu.JaxTestCase):
11241147

11251148
def test_inspect_sharding_is_called_in_pjit(self):

0 commit comments

Comments
 (0)