@@ -672,27 +672,27 @@ def max(x: tensor, y: tensor) -> tensor:
     assert x.shape == y.shape
     if x.dtype.is_floating():
       # TODO(slebedev): Triton promotes bfloat16 to float32 and back here.
-      return tensor(arith_dialect.maxnumf(x.handle, y.handle), x.dtype)
+      return tensor(arith_dialect.maxnumf(x.handle, y.handle), x.type)
     if not x.dtype.is_int():
       raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}")
     elif x.dtype.is_int_signed():
-      return tensor(arith_dialect.maxsi(x.handle, y.handle), x.dtype)
+      return tensor(arith_dialect.maxsi(x.handle, y.handle), x.type)
     else:
-      return tensor(arith_dialect.maxui(x.handle, y.handle), x.dtype)
+      return tensor(arith_dialect.maxui(x.handle, y.handle), x.type)

   @staticmethod
   def min(x: tensor, y: tensor) -> tensor:
     # TODO(slebedev): Consider allowing customizing nan behavior.
     assert x.shape == y.shape
     if x.dtype.is_floating():
       # TODO(slebedev): Triton promotes bfloat16 to float32 and back here.
-      return tensor(arith_dialect.minnumf(x.handle, y.handle), x.dtype)
+      return tensor(arith_dialect.minnumf(x.handle, y.handle), x.type)
     if not x.dtype.is_int():
       raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}")
     elif x.dtype.is_int_signed():
-      return tensor(arith_dialect.minsi(x.handle, y.handle), x.dtype)
+      return tensor(arith_dialect.minsi(x.handle, y.handle), x.type)
     else:
-      return tensor(arith_dialect.minui(x.handle, y.handle), x.dtype)
+      return tensor(arith_dialect.minui(x.handle, y.handle), x.type)

   sin = libdevice_extern_elementwise({
       (float32,): ("__nv_sinf", float32),
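The fix above swaps the second argument of the `tensor` wrapper from `x.dtype` to `x.type`. Below is a minimal self-contained sketch of why that matters, assuming a Triton-style wrapper where `type` is the full block type (element type plus shape) while `dtype` is only the element type; every class and name in it is a hypothetical stand-in, not the real jax or Triton definitions:

```python
# Standalone sketch (not the real jax/Triton classes): a minimal `tensor`
# wrapper in the style of Triton's, showing why the constructor needs the
# full block type (`x.type`, which carries the shape) rather than the
# element type (`x.dtype`). All names here are hypothetical simplifications.
from dataclasses import dataclass


@dataclass(frozen=True)
class dtype:
    name: str  # e.g. "fp32" or "int32"


@dataclass(frozen=True)
class block_type:
    element_ty: dtype
    shape: list[int]


@dataclass
class tensor:
    handle: object  # stands in for the underlying MLIR value
    type: object    # either a dtype (scalar) or a block_type

    @property
    def dtype(self) -> dtype:
        # For a block, the element type; for a scalar, the type itself.
        if isinstance(self.type, block_type):
            return self.type.element_ty
        return self.type

    @property
    def shape(self) -> list[int]:
        if isinstance(self.type, block_type):
            return self.type.shape
        return []


x = tensor(handle=None, type=block_type(dtype("fp32"), [8, 128]))

buggy = tensor(None, x.dtype)  # the pre-fix code path: shape is lost
fixed = tensor(None, x.type)   # the post-fix code path: shape survives

assert buggy.shape == []        # result silently became a "scalar"
assert fixed.shape == [8, 128]  # result keeps the block shape
```

One side note on the floating-point branch: `arith.maxnumf`/`arith.minnumf` follow the libm `fmax`/`fmin` convention of returning the non-NaN operand when one argument is NaN, which is presumably why the TODO about customizing NaN behavior sits on these methods.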