Skip to content

Commit 3bc8436

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Add __str__ to UnshapedArray so that whenever we print(aval), we don't see the class name by default. It only shows up when you do: repr(aval).
Weak_type in `__str__` is represented as `~int32[5, 4]` (note the tilde at the start) PiperOrigin-RevId: 751066142
1 parent d13ac0a commit 3bc8436

File tree

5 files changed

+13
-3
lines changed

5 files changed

+13
-3
lines changed

jax/_src/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ def disable_jit(disable: bool = True):
358358
... return y + 3
359359
...
360360
>>> print(f(jax.numpy.array([1, 2, 3]))) # doctest:+ELLIPSIS
361-
Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace...>
361+
Value of y is Traced<int32[3]>with<DynamicJaxprTrace...>
362362
[5 7 9]
363363
364364
Here ``y`` has been abstracted by :py:func:`jit` to a :py:class:`ShapedArray`,

jax/_src/core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,6 +1688,9 @@ def __repr__(self):
16881688
return '{}({}{})'.format(self.__class__.__name__, self.str_short(),
16891689
", weak_type=True" if self.weak_type else "")
16901690

1691+
def __str__(self):
1692+
return '{}{}'.format("~" if self.weak_type else "", self.str_short())
1693+
16911694
_bool = concretization_function_error(bool)
16921695
_int = concretization_function_error(int, True)
16931696
_float = concretization_function_error(float, True)

jax/_src/tree_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ class Partial(functools.partial):
529529
>>> print_zero()
530530
0
531531
>>> call_func(print_zero) # doctest:+ELLIPSIS
532-
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace...>
532+
Traced<~int32[]>with<DynamicJaxprTrace...>
533533
"""
534534

535535
def __new__(klass, func, *args, **kw):

tests/core_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,13 @@ def test_is_valid_jaxtype(self, dtype):
203203
else:
204204
self.assertFalse(core.valid_jaxtype(arr))
205205

206+
def test_str_aval(self):
207+
aval = ShapedArray((8, 2), np.int32)
208+
self.assertEqual(str(aval), "int32[8,2]")
209+
210+
aval = ShapedArray((8, 2), np.int32, weak_type=True)
211+
self.assertEqual(str(aval), "~int32[8,2]")
212+
206213
@parameterized.named_parameters(
207214
(str(i), *spec) for i, spec in enumerate(test_specs))
208215
def test_jit(self, f, args):

tests/lax_control_flow_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1941,7 +1941,7 @@ def plus_one(p, iter_idx):
19411941
def testScanBodyOutputError(self):
19421942
with self.assertRaisesRegex(
19431943
TypeError,
1944-
re.escape("scan body output must be a pair, got ShapedArray(float32[]).")):
1944+
re.escape("scan body output must be a pair, got float32[].")):
19451945
lax.scan(lambda c, x: np.float32(0.), 0, jnp.arange(5.))
19461946

19471947
def testScanMetadataError(self):

0 commit comments

Comments
 (0)