Skip to content

Commit 9e94e6e

Browse files
superbobryjax authors
authored andcommitted
Fixed a typo in min/max Triton lowering rules
PiperOrigin-RevId: 604424404
1 parent b53f757 commit 9e94e6e

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

jaxlib/triton/compat.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -672,27 +672,27 @@ def max(x: tensor, y: tensor) -> tensor:
672672
assert x.shape == y.shape
673673
if x.dtype.is_floating():
674674
# TODO(slebedev): Triton promotes bfloat16 to float32 and back here.
675-
return tensor(arith_dialect.maxnumf(x.handle, y.handle), x.dtype)
675+
return tensor(arith_dialect.maxnumf(x.handle, y.handle), x.type)
676676
if not x.dtype.is_int():
677677
raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}")
678678
elif x.dtype.is_int_signed():
679-
return tensor(arith_dialect.maxsi(x.handle, y.handle), x.dtype)
679+
return tensor(arith_dialect.maxsi(x.handle, y.handle), x.type)
680680
else:
681-
return tensor(arith_dialect.maxui(x.handle, y.handle), x.dtype)
681+
return tensor(arith_dialect.maxui(x.handle, y.handle), x.type)
682682

683683
@staticmethod
684684
def min(x: tensor, y: tensor) -> tensor:
685685
# TODO(slebedev): Consider allowing customizing nan behavior.
686686
assert x.shape == y.shape
687687
if x.dtype.is_floating():
688688
# TODO(slebedev): Triton promotes bfloat16 to float32 and back here.
689-
return tensor(arith_dialect.minnumf(x.handle, y.handle), x.dtype)
689+
return tensor(arith_dialect.minnumf(x.handle, y.handle), x.type)
690690
if not x.dtype.is_int():
691691
raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}")
692692
elif x.dtype.is_int_signed():
693-
return tensor(arith_dialect.minsi(x.handle, y.handle), x.dtype)
693+
return tensor(arith_dialect.minsi(x.handle, y.handle), x.type)
694694
else:
695-
return tensor(arith_dialect.minui(x.handle, y.handle), x.dtype)
695+
return tensor(arith_dialect.minui(x.handle, y.handle), x.type)
696696

697697
sin = libdevice_extern_elementwise({
698698
(float32,): ("__nv_sinf", float32),

0 commit comments

Comments
 (0)