Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 40 additions & 29 deletions hugr-py/src/hugr/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import hugr._serialization.tys as stys
import hugr.model as model
from hugr.utils import comma_sep_repr, comma_sep_str, ser_it
from hugr.utils import comma_sep_repr, comma_sep_str, comma_sep_str_paren, ser_it

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
Expand Down Expand Up @@ -430,7 +430,38 @@
return Tuple(*self.variant_rows[0])

def __repr__(self) -> str:
return f"Sum({self.variant_rows})"
if self == Bool:
return "Bool"
elif self == Unit:
return "Unit"
elif all(len(row) == 0 for row in self.variant_rows):
return f"UnitSum({len(self.variant_rows)})"
elif len(self.variant_rows) == 1:
return f"Tuple{tuple(self.variant_rows[0])}"
elif len(self.variant_rows) == 2 and len(self.variant_rows[0]) == 0:
return f"Option({comma_sep_repr(self.variant_rows[1])})"
elif len(self.variant_rows) == 2:
left, right = self.variant_rows
return f"Either(left={left}, right={right})"
else:
return f"Sum({self.variant_rows})"

def __str__(self) -> str:
if self == Bool:
return "Bool"

Check warning on line 451 in hugr-py/src/hugr/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/tys.py#L451

Added line #L451 was not covered by tests
elif self == Unit:
return "Unit"

Check warning on line 453 in hugr-py/src/hugr/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/tys.py#L453

Added line #L453 was not covered by tests
elif all(len(row) == 0 for row in self.variant_rows):
return f"UnitSum({len(self.variant_rows)})"

Check warning on line 455 in hugr-py/src/hugr/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/tys.py#L455

Added line #L455 was not covered by tests
elif len(self.variant_rows) == 1:
return f"Tuple{tuple(self.variant_rows[0])}"
elif len(self.variant_rows) == 2 and len(self.variant_rows[0]) == 0:
return f"Option({comma_sep_str(self.variant_rows[1])})"
elif len(self.variant_rows) == 2:
left, right = self.variant_rows
return f"Either({comma_sep_str_paren(left)}, {comma_sep_str_paren(right)})"
else:
return f"Sum({self.variant_rows})"

def __eq__(self, other: object) -> bool:
return isinstance(other, Sum) and self.variant_rows == other.variant_rows
Expand All @@ -449,7 +480,7 @@
return model.Apply("core.adt", [variants])


@dataclass(eq=False)
@dataclass(eq=False, repr=False)
class UnitSum(Sum):
"""Simple :class:`Sum` type with `size` variants of empty rows."""

Expand All @@ -462,18 +493,14 @@
def _to_serial(self) -> stys.UnitSum: # type: ignore[override]
return stys.UnitSum(size=self.size)

def __repr__(self) -> str:
if self == Bool:
return "Bool"
elif self == Unit:
return "Unit"
return f"UnitSum({self.size})"

def resolve(self, registry: ext.ExtensionRegistry) -> UnitSum:
return self

def __str__(self) -> str:
return self.__repr__()


@dataclass(eq=False)
@dataclass(eq=False, repr=False)
class Tuple(Sum):
"""Product type with `tys` elements. Instances of this type correspond to
:class:`Sum` with a single variant.
Expand All @@ -482,11 +509,8 @@
def __init__(self, *tys: Type):
self.variant_rows = [list(tys)]

def __repr__(self) -> str:
return f"Tuple{tuple(self.variant_rows[0])}"


@dataclass(eq=False)
@dataclass(eq=False, repr=False)
class Option(Sum):
"""Optional tuple of elements.

Expand All @@ -497,11 +521,8 @@
def __init__(self, *tys: Type):
self.variant_rows = [[], list(tys)]

def __repr__(self) -> str:
return f"Option({comma_sep_repr(self.variant_rows[1])})"


@dataclass(eq=False)
@dataclass(eq=False, repr=False)
class Either(Sum):
"""Two-variant tuple of elements.

Expand All @@ -514,16 +535,6 @@
def __init__(self, left: Iterable[Type], right: Iterable[Type]):
self.variant_rows = [list(left), list(right)]

def __repr__(self) -> str: # pragma: no cover
left, right = self.variant_rows
return f"Either(left={left}, right={right})"

def __str__(self) -> str:
left, right = self.variant_rows
left_str = left[0] if len(left) == 1 else tuple(left)
right_str = right[0] if len(right) == 1 else tuple(right)
return f"Either({left_str}, {right_str})"


@dataclass(frozen=True)
class Variable(Type):
Expand Down
24 changes: 24 additions & 0 deletions hugr-py/src/hugr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,27 @@
def comma_sep_repr(items: Iterable[T]) -> str:
"""Join items with commas and repr."""
return ", ".join(map(repr, items))


def comma_sep_str_paren(items: Iterable[T]) -> str:
"""Join items with commas and str, wrapping them in parentheses if more than one."""
items = list(items)
if len(items) == 0:
return "()"

Check warning on line 224 in hugr-py/src/hugr/utils.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/utils.py#L224

Added line #L224 was not covered by tests
elif len(items) == 1:
return f"{items[0]}"
else:
return f"({comma_sep_str(items)})"


def comma_sep_repr_paren(items: Iterable[T]) -> str:
"""Join items with commas and repr, wrapping them in parentheses if more
than one.
"""
items = list(items)
if len(items) == 0:
return "()"
elif len(items) == 1:
return f"{items[0]}"

Check warning on line 239 in hugr-py/src/hugr/utils.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/utils.py#L235-L239

Added lines #L235 - L239 were not covered by tests
else:
return f"({comma_sep_repr(items)})"

Check warning on line 241 in hugr-py/src/hugr/utils.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/utils.py#L241

Added line #L241 was not covered by tests
102 changes: 62 additions & 40 deletions hugr-py/src/hugr/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ class Sum(Value):
"""Sum-of-product value.

Example:
>>> Sum(0, tys.Sum([[tys.Bool], [tys.Unit]]), [TRUE])
Sum(tag=0, typ=Sum([[Bool], [Unit]]), vals=[TRUE])
>>> Sum(0, tys.Sum([[tys.Bool], [tys.Unit], [tys.Bool]]), [TRUE])
Sum(tag=0, typ=Sum([[Bool], [Unit], [Bool]]), vals=[TRUE])
"""

#: Tag identifying the variant.
Expand All @@ -70,6 +70,59 @@ def _to_serial(self) -> sops.SumValue:
vs=ser_it(self.vals),
)

def __repr__(self) -> str:
if self == TRUE:
return "TRUE"
elif self == FALSE:
return "FALSE"
elif self == Unit:
return "Unit"
elif all(len(row) == 0 for row in self.typ.variant_rows):
return f"UnitSum({self.tag}, {self.n_variants})"
elif len(self.typ.variant_rows) == 1:
return f"Tuple({comma_sep_repr(self.vals)})"
elif len(self.typ.variant_rows) == 2 and len(self.typ.variant_rows[0]) == 0:
# Option
if self.tag == 0:
return f"None({comma_sep_str(self.typ.variant_rows[1])})"
else:
return f"Some({comma_sep_repr(self.vals)})"
elif len(self.typ.variant_rows) == 2:
# Either
left_typ, right_typ = self.typ.variant_rows
if self.tag == 0:
return f"Left(vals={self.vals}, right_typ={list(right_typ)})"
else:
return f"Right(left_typ={list(left_typ)}, vals={self.vals})"
else:
return f"Sum(tag={self.tag}, typ={self.typ}, vals={self.vals})"

def __str__(self) -> str:
if self == TRUE:
return "TRUE"
elif self == FALSE:
return "FALSE"
elif self == Unit:
return "Unit"
elif all(len(row) == 0 for row in self.typ.variant_rows):
return f"UnitSum({self.tag}, {self.n_variants})"
elif len(self.typ.variant_rows) == 1:
return f"Tuple({comma_sep_str(self.vals)})"
elif len(self.typ.variant_rows) == 2 and len(self.typ.variant_rows[0]) == 0:
# Option
if self.tag == 0:
return "None"
else:
return f"Some({comma_sep_str(self.vals)})"
elif len(self.typ.variant_rows) == 2:
# Either
if self.tag == 0:
return f"Left({comma_sep_str(self.vals)})"
else:
return f"Right({comma_sep_str(self.vals)})"
else:
return f"Sum({self.tag}, {self.typ}, {self.vals})"

def __eq__(self, other: object) -> bool:
return (
isinstance(other, Sum)
Expand Down Expand Up @@ -100,6 +153,7 @@ def to_model(self) -> model.Term:
)


@dataclass(eq=False, repr=False)
class UnitSum(Sum):
"""Simple :class:`Sum` with each variant being an empty row.

Expand All @@ -119,15 +173,6 @@ def __init__(self, tag: int, size: int):
vals=[],
)

def __repr__(self) -> str:
if self == TRUE:
return "TRUE"
if self == FALSE:
return "FALSE"
if self == Unit:
return "Unit"
return f"UnitSum({self.tag}, {self.n_variants})"


def bool_value(b: bool) -> UnitSum:
"""Convert a python bool to a HUGR boolean value.
Expand All @@ -149,7 +194,7 @@ def bool_value(b: bool) -> UnitSum:
FALSE = bool_value(False)


@dataclass(eq=False)
@dataclass(eq=False, repr=False)
class Tuple(Sum):
"""Tuple or product value, defined by a list of values.
Internally a :class:`Sum` with a single variant row.
Expand Down Expand Up @@ -177,10 +222,10 @@ def _to_serial(self) -> sops.TupleValue: # type: ignore[override]
)

def __repr__(self) -> str:
return f"Tuple({comma_sep_repr(self.vals)})"
return super().__repr__()


@dataclass(eq=False)
@dataclass(eq=False, repr=False)
class Some(Sum):
"""Optional tuple of value, containing a list of values.

Expand All @@ -199,11 +244,8 @@ def __init__(self, *vals: Value):
tag=1, typ=tys.Option(*(v.type_() for v in val_list)), vals=val_list
)

def __repr__(self) -> str:
return f"Some({comma_sep_repr(self.vals)})"


@dataclass(eq=False)
@dataclass(eq=False, repr=False)
class None_(Sum):
"""Optional tuple of value, containing no values.

Expand All @@ -219,14 +261,8 @@ class None_(Sum):
def __init__(self, *types: tys.Type):
super().__init__(tag=0, typ=tys.Option(*types), vals=[])

def __repr__(self) -> str:
return f"None({comma_sep_str(self.typ.variant_rows[1])})"

def __str__(self) -> str:
return "None"


@dataclass(eq=False)
@dataclass(eq=False, repr=False)
class Left(Sum):
"""Left variant of a :class:`tys.Either` type, containing a list of values.

Expand All @@ -248,15 +284,8 @@ def __init__(self, vals: Iterable[Value], right_typ: Iterable[tys.Type]):
vals=val_list,
)

def __repr__(self) -> str:
_, right_typ = self.typ.variant_rows
return f"Left(vals={self.vals}, right_typ={list(right_typ)})"

def __str__(self) -> str:
return f"Left({comma_sep_str(self.vals)})"


@dataclass(eq=False)
@dataclass(eq=False, repr=False)
class Right(Sum):
"""Right variant of a :class:`tys.Either` type, containing a list of values.

Expand All @@ -280,13 +309,6 @@ def __init__(self, left_typ: Iterable[tys.Type], vals: Iterable[Value]):
vals=val_list,
)

def __repr__(self) -> str:
left_typ, _ = self.typ.variant_rows
return f"Right(left_typ={list(left_typ)}, vals={self.vals})"

def __str__(self) -> str:
return f"Right({comma_sep_str(self.vals)})"


@dataclass
class Function(Value):
Expand Down
14 changes: 9 additions & 5 deletions hugr-py/tests/test_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Sum,
Tuple,
UnitSum,
Value,
bool_value,
)

Expand Down Expand Up @@ -44,9 +43,9 @@ def test_sums():
("value", "string", "repr_str"),
[
(
Sum(0, tys.Sum([[tys.Bool], [tys.Qubit]]), [TRUE, FALSE]),
"Sum(tag=0, typ=Sum([[Bool], [Qubit]]), vals=[TRUE, FALSE])",
"Sum(tag=0, typ=Sum([[Bool], [Qubit]]), vals=[TRUE, FALSE])",
Sum(0, tys.Sum([[tys.Bool], [tys.Qubit], [tys.Bool]]), [TRUE]),
"Sum(0, Sum([[Bool], [Qubit], [Bool]]), [TRUE])",
"Sum(tag=0, typ=Sum([[Bool], [Qubit], [Bool]]), vals=[TRUE])",
),
(UnitSum(0, size=1), "Unit", "Unit"),
(UnitSum(0, size=2), "FALSE", "FALSE"),
Expand All @@ -67,10 +66,15 @@ def test_sums():
),
],
)
def test_val_sum_str(value: Value, string: str, repr_str: str):
def test_val_sum_str(value: Sum, string: str, repr_str: str):
assert str(value) == string
assert repr(value) == repr_str

# Make sure the corresponding `Sum` also renders the same
sum_val = Sum(value.tag, value.typ, value.vals)
assert str(sum_val) == string
assert repr(sum_val) == repr_str


def test_val_static_array():
from hugr.std.collections.static_array import StaticArrayVal
Expand Down
Loading