Skip to content

Commit b53f757

Browse files
author
jax authors
committed
Merge pull request #19667 from jakevdp:array-empty-repr
PiperOrigin-RevId: 604424058
2 parents c1c0c1c + d9cbd7b commit b53f757

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

jax/_src/array.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,11 @@ def __repr__(self):
376376

377377
if self.is_fully_addressable or self.is_fully_replicated:
378378
line_width = np.get_printoptions()["linewidth"]
379-
s = np.array2string(self._value, prefix=prefix, suffix=',',
380-
separator=', ', max_line_width=line_width)
379+
if self.size == 0:
380+
s = f"[], shape={self.shape}"
381+
else:
382+
s = np.array2string(self._value, prefix=prefix, suffix=',',
383+
separator=', ', max_line_width=line_width)
381384
last_line_len = len(s) - s.rfind('\n') + 1
382385
sep = ' '
383386
if last_line_len + len(dtype_str) + 1 > line_width:

tests/array_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,12 @@ def test_repr(self):
231231
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
232232
self.assertStartsWith(repr(arr), "Array(")
233233

234+
def test_empty_repr(self):
235+
shape = (0, 5)
236+
dtype = 'float32'
237+
x = jnp.empty(shape, dtype)
238+
self.assertEqual(repr(x), f"Array([], shape={shape}, dtype={dtype})")
239+
234240
def test_jnp_array(self):
235241
arr = jnp.array([1, 2, 3])
236242
self.assertIsInstance(arr, array.ArrayImpl)

0 commit comments

Comments
 (0)