diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index 3eb89696f..9fa654c96 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -8,7 +8,7 @@ import hugr._serialization.tys as stys import hugr.model as model -from hugr.utils import comma_sep_repr, comma_sep_str, ser_it +from hugr.utils import comma_sep_repr, comma_sep_str, comma_sep_str_paren, ser_it if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -430,7 +430,38 @@ def as_tuple(self) -> Tuple: return Tuple(*self.variant_rows[0]) def __repr__(self) -> str: - return f"Sum({self.variant_rows})" + if self == Bool: + return "Bool" + elif self == Unit: + return "Unit" + elif all(len(row) == 0 for row in self.variant_rows): + return f"UnitSum({len(self.variant_rows)})" + elif len(self.variant_rows) == 1: + return f"Tuple{tuple(self.variant_rows[0])}" + elif len(self.variant_rows) == 2 and len(self.variant_rows[0]) == 0: + return f"Option({comma_sep_repr(self.variant_rows[1])})" + elif len(self.variant_rows) == 2: + left, right = self.variant_rows + return f"Either(left={left}, right={right})" + else: + return f"Sum({self.variant_rows})" + + def __str__(self) -> str: + if self == Bool: + return "Bool" + elif self == Unit: + return "Unit" + elif all(len(row) == 0 for row in self.variant_rows): + return f"UnitSum({len(self.variant_rows)})" + elif len(self.variant_rows) == 1: + return f"Tuple{tuple(self.variant_rows[0])}" + elif len(self.variant_rows) == 2 and len(self.variant_rows[0]) == 0: + return f"Option({comma_sep_str(self.variant_rows[1])})" + elif len(self.variant_rows) == 2: + left, right = self.variant_rows + return f"Either({comma_sep_str_paren(left)}, {comma_sep_str_paren(right)})" + else: + return f"Sum({self.variant_rows})" def __eq__(self, other: object) -> bool: return isinstance(other, Sum) and self.variant_rows == other.variant_rows @@ -449,7 +480,7 @@ def to_model(self) -> model.Term: return model.Apply("core.adt", [variants]) -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class UnitSum(Sum): """Simple :class:`Sum` type with `size` variants of empty rows.""" @@ -462,18 +493,14 @@ def __init__(self, size: int): def _to_serial(self) -> stys.UnitSum: # type: ignore[override] return stys.UnitSum(size=self.size) - def __repr__(self) -> str: - if self == Bool: - return "Bool" - elif self == Unit: - return "Unit" - return f"UnitSum({self.size})" - def resolve(self, registry: ext.ExtensionRegistry) -> UnitSum: return self + def __str__(self) -> str: + return self.__repr__() + -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class Tuple(Sum): """Product type with `tys` elements. Instances of this type correspond to :class:`Sum` with a single variant. @@ -482,11 +509,8 @@ class Tuple(Sum): def __init__(self, *tys: Type): self.variant_rows = [list(tys)] - def __repr__(self) -> str: - return f"Tuple{tuple(self.variant_rows[0])}" - -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class Option(Sum): """Optional tuple of elements. @@ -497,11 +521,8 @@ class Option(Sum): def __init__(self, *tys: Type): self.variant_rows = [[], list(tys)] - def __repr__(self) -> str: - return f"Option({comma_sep_repr(self.variant_rows[1])})" - -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class Either(Sum): """Two-variant tuple of elements. @@ -514,16 +535,6 @@ class Either(Sum): def __init__(self, left: Iterable[Type], right: Iterable[Type]): self.variant_rows = [list(left), list(right)] - def __repr__(self) -> str: # pragma: no cover - left, right = self.variant_rows - return f"Either(left={left}, right={right})" - - def __str__(self) -> str: - left, right = self.variant_rows - left_str = left[0] if len(left) == 1 else tuple(left) - right_str = right[0] if len(right) == 1 else tuple(right) - return f"Either({left_str}, {right_str})" - @dataclass(frozen=True) class Variable(Type): diff --git a/hugr-py/src/hugr/utils.py b/hugr-py/src/hugr/utils.py index 480f3337b..0c6048ec3 100644 --- a/hugr-py/src/hugr/utils.py +++ b/hugr-py/src/hugr/utils.py @@ -215,3 +215,27 @@ def comma_sep_str(items: Iterable[T]) -> str: def comma_sep_repr(items: Iterable[T]) -> str: """Join items with commas and repr.""" return ", ".join(map(repr, items)) + + +def comma_sep_str_paren(items: Iterable[T]) -> str: + """Join items with commas and str, wrapping them in parentheses if more than one.""" + items = list(items) + if len(items) == 0: + return "()" + elif len(items) == 1: + return f"{items[0]}" + else: + return f"({comma_sep_str(items)})" + + +def comma_sep_repr_paren(items: Iterable[T]) -> str: + """Join items with commas and repr, wrapping them in parentheses if more + than one. + """ + items = list(items) + if len(items) == 0: + return "()" + elif len(items) == 1: + return f"{items[0]}" + else: + return f"({comma_sep_repr(items)})" diff --git a/hugr-py/src/hugr/val.py b/hugr-py/src/hugr/val.py index 391f865cb..663f64faa 100644 --- a/hugr-py/src/hugr/val.py +++ b/hugr-py/src/hugr/val.py @@ -45,8 +45,8 @@ class Sum(Value): """Sum-of-product value. Example: - >>> Sum(0, tys.Sum([[tys.Bool], [tys.Unit]]), [TRUE]) - Sum(tag=0, typ=Sum([[Bool], [Unit]]), vals=[TRUE]) + >>> Sum(0, tys.Sum([[tys.Bool], [tys.Unit], [tys.Bool]]), [TRUE]) + Sum(tag=0, typ=Sum([[Bool], [Unit], [Bool]]), vals=[TRUE]) """ #: Tag identifying the variant. @@ -70,6 +70,59 @@ def _to_serial(self) -> sops.SumValue: vs=ser_it(self.vals), ) + def __repr__(self) -> str: + if self == TRUE: + return "TRUE" + elif self == FALSE: + return "FALSE" + elif self == Unit: + return "Unit" + elif all(len(row) == 0 for row in self.typ.variant_rows): + return f"UnitSum({self.tag}, {self.n_variants})" + elif len(self.typ.variant_rows) == 1: + return f"Tuple({comma_sep_repr(self.vals)})" + elif len(self.typ.variant_rows) == 2 and len(self.typ.variant_rows[0]) == 0: + # Option + if self.tag == 0: + return f"None({comma_sep_str(self.typ.variant_rows[1])})" + else: + return f"Some({comma_sep_repr(self.vals)})" + elif len(self.typ.variant_rows) == 2: + # Either + left_typ, right_typ = self.typ.variant_rows + if self.tag == 0: + return f"Left(vals={self.vals}, right_typ={list(right_typ)})" + else: + return f"Right(left_typ={list(left_typ)}, vals={self.vals})" + else: + return f"Sum(tag={self.tag}, typ={self.typ}, vals={self.vals})" + + def __str__(self) -> str: + if self == TRUE: + return "TRUE" + elif self == FALSE: + return "FALSE" + elif self == Unit: + return "Unit" + elif all(len(row) == 0 for row in self.typ.variant_rows): + return f"UnitSum({self.tag}, {self.n_variants})" + elif len(self.typ.variant_rows) == 1: + return f"Tuple({comma_sep_str(self.vals)})" + elif len(self.typ.variant_rows) == 2 and len(self.typ.variant_rows[0]) == 0: + # Option + if self.tag == 0: + return "None" + else: + return f"Some({comma_sep_str(self.vals)})" + elif len(self.typ.variant_rows) == 2: + # Either + if self.tag == 0: + return f"Left({comma_sep_str(self.vals)})" + else: + return f"Right({comma_sep_str(self.vals)})" + else: + return f"Sum({self.tag}, {self.typ}, {self.vals})" + def __eq__(self, other: object) -> bool: return ( isinstance(other, Sum) @@ -100,6 +153,7 @@ def to_model(self) -> model.Term: ) +@dataclass(eq=False, repr=False) class UnitSum(Sum): """Simple :class:`Sum` with each variant being an empty row. @@ -119,15 +173,6 @@ def __init__(self, tag: int, size: int): vals=[], ) - def __repr__(self) -> str: - if self == TRUE: - return "TRUE" - if self == FALSE: - return "FALSE" - if self == Unit: - return "Unit" - return f"UnitSum({self.tag}, {self.n_variants})" - def bool_value(b: bool) -> UnitSum: """Convert a python bool to a HUGR boolean value. @@ -149,7 +194,7 @@ def bool_value(b: bool) -> UnitSum: FALSE = bool_value(False) -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class Tuple(Sum): """Tuple or product value, defined by a list of values. Internally a :class:`Sum` with a single variant row. @@ -177,10 +222,10 @@ def _to_serial(self) -> sops.TupleValue: # type: ignore[override] ) def __repr__(self) -> str: - return f"Tuple({comma_sep_repr(self.vals)})" + return super().__repr__() -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class Some(Sum): """Optional tuple of value, containing a list of values. @@ -199,11 +244,8 @@ def __init__(self, *vals: Value): tag=1, typ=tys.Option(*(v.type_() for v in val_list)), vals=val_list ) - def __repr__(self) -> str: - return f"Some({comma_sep_repr(self.vals)})" - -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class None_(Sum): """Optional tuple of value, containing no values. @@ -219,14 +261,8 @@ class None_(Sum): def __init__(self, *types: tys.Type): super().__init__(tag=0, typ=tys.Option(*types), vals=[]) - def __repr__(self) -> str: - return f"None({comma_sep_str(self.typ.variant_rows[1])})" - - def __str__(self) -> str: - return "None" - -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class Left(Sum): """Left variant of a :class:`tys.Either` type, containing a list of values. @@ -248,15 +284,8 @@ def __init__(self, vals: Iterable[Value], right_typ: Iterable[tys.Type]): vals=val_list, ) - def __repr__(self) -> str: - _, right_typ = self.typ.variant_rows - return f"Left(vals={self.vals}, right_typ={list(right_typ)})" - - def __str__(self) -> str: - return f"Left({comma_sep_str(self.vals)})" - -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class Right(Sum): """Right variant of a :class:`tys.Either` type, containing a list of values. @@ -280,13 +309,6 @@ def __init__(self, left_typ: Iterable[tys.Type], vals: Iterable[Value]): vals=val_list, ) - def __repr__(self) -> str: - left_typ, _ = self.typ.variant_rows - return f"Right(left_typ={list(left_typ)}, vals={self.vals})" - - def __str__(self) -> str: - return f"Right({comma_sep_str(self.vals)})" - @dataclass class Function(Value): diff --git a/hugr-py/tests/test_val.py b/hugr-py/tests/test_val.py index 11fd1a119..5afab8840 100644 --- a/hugr-py/tests/test_val.py +++ b/hugr-py/tests/test_val.py @@ -14,7 +14,6 @@ Sum, Tuple, UnitSum, - Value, bool_value, ) @@ -44,9 +43,9 @@ def test_sums(): ("value", "string", "repr_str"), [ ( - Sum(0, tys.Sum([[tys.Bool], [tys.Qubit]]), [TRUE, FALSE]), - "Sum(tag=0, typ=Sum([[Bool], [Qubit]]), vals=[TRUE, FALSE])", - "Sum(tag=0, typ=Sum([[Bool], [Qubit]]), vals=[TRUE, FALSE])", + Sum(0, tys.Sum([[tys.Bool], [tys.Qubit], [tys.Bool]]), [TRUE]), + "Sum(0, Sum([[Bool], [Qubit], [Bool]]), [TRUE])", + "Sum(tag=0, typ=Sum([[Bool], [Qubit], [Bool]]), vals=[TRUE])", ), (UnitSum(0, size=1), "Unit", "Unit"), (UnitSum(0, size=2), "FALSE", "FALSE"), @@ -67,10 +66,15 @@ def test_sums(): ), ], ) -def test_val_sum_str(value: Value, string: str, repr_str: str): +def test_val_sum_str(value: Sum, string: str, repr_str: str): assert str(value) == string assert repr(value) == repr_str + # Make sure the corresponding `Sum` also renders the same + sum_val = Sum(value.tag, value.typ, value.vals) + assert str(sum_val) == string + assert repr(sum_val) == repr_str + def test_val_static_array(): from hugr.std.collections.static_array import StaticArrayVal