From a7b27180cbb85490c09f8e24f46eeb4d5fd5eb21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Thu, 29 Aug 2024 17:01:10 +0100 Subject: [PATCH] fix: Sum value equality. Add unit tests (#1484) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This was supposed to be part of #1481, but pushed it to the branch that depended on it instead 🤦 - Adds the string/repr unit tests suggested by https://github.com/CQCL/hugr/pull/1481#pullrequestreview-2268976031 - Tests—and fixes—equality comparation between Sum values. --- hugr-py/src/hugr/tys.py | 7 ----- hugr-py/src/hugr/val.py | 38 +++++++++----------------- hugr-py/tests/test_tys.py | 34 ++++++++++++++++++++++- hugr-py/tests/test_val.py | 57 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 103 insertions(+), 33 deletions(-) create mode 100644 hugr-py/tests/test_val.py diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index 045198b74..7f274f431 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -328,13 +328,6 @@ class Either(Sum): In fallible contexts, the Left variant is used to represent success, and the Right variant is used to represent failure. - - Example: - >>> either = Either([Bool, Bool], [Bool]) - >>> either - Either(left=[Bool, Bool], right=[Bool]) - >>> str(either) - 'Either((Bool, Bool), Bool)' """ def __init__(self, left: Iterable[Type], right: Iterable[Type]): diff --git a/hugr-py/src/hugr/val.py b/hugr-py/src/hugr/val.py index 7757636b1..f9735c6cf 100644 --- a/hugr-py/src/hugr/val.py +++ b/hugr-py/src/hugr/val.py @@ -67,6 +67,14 @@ def _to_serial(self) -> sops.SumValue: vs=ser_it(self.vals), ) + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, Sum) + and self.tag == other.tag + and self.typ == other.typ + and self.vals == other.vals + ) + class UnitSum(Sum): """Simple :class:`Sum` with each variant being an empty row. @@ -117,7 +125,7 @@ def bool_value(b: bool) -> UnitSum: FALSE = bool_value(False) -@dataclass +@dataclass(eq=False) class Tuple(Sum): """Tuple or product value, defined by a list of values. Internally a :class:`Sum` with a single variant row. @@ -131,9 +139,6 @@ class Tuple(Sum): """ - #: The values of this tuple. - vals: list[Value] - def __init__(self, *vals: Value): val_list = list(vals) super().__init__( @@ -151,14 +156,12 @@ def __repr__(self) -> str: return f"Tuple({', '.join(map(repr, self.vals))})" -@dataclass +@dataclass(eq=False) class Some(Sum): """Optional tuple of value, containing a list of values. Example: >>> some = Some(TRUE, FALSE) - >>> some - Some(TRUE, FALSE) >>> str(some) 'Some(TRUE, FALSE)' >>> some.type_() @@ -166,9 +169,6 @@ class Some(Sum): """ - #: The values of this tuple. - vals: list[Value] - def __init__(self, *vals: Value): val_list = list(vals) super().__init__( @@ -179,14 +179,12 @@ def __repr__(self) -> str: return f"Some({', '.join(map(repr, self.vals))})" -@dataclass +@dataclass(eq=False) class None_(Sum): """Optional tuple of value, containing no values. Example: >>> none = None_(tys.Bool) - >>> none - None(Bool) >>> str(none) 'None' >>> none.type_() @@ -204,7 +202,7 @@ def __str__(self) -> str: return "None" -@dataclass +@dataclass(eq=False) class Left(Sum): """Left variant of a :class:`tys.Either` type, containing a list of values. @@ -212,17 +210,12 @@ class Left(Sum): Example: >>> left = Left([TRUE, FALSE], [tys.Bool]) - >>> left - Left(vals=[TRUE, FALSE], right_typ=[Bool]) >>> str(left) 'Left(TRUE, FALSE)' >>> str(left.type_()) 'Either((Bool, Bool), Bool)' """ - #: The values of this tuple. - vals: list[Value] - def __init__(self, vals: Iterable[Value], right_typ: Iterable[tys.Type]): val_list = list(vals) super().__init__( @@ -240,7 +233,7 @@ def __str__(self) -> str: return f"Left({vals_str})" -@dataclass +@dataclass(eq=False) class Right(Sum): """Right variant of a :class:`tys.Either` type, containing a list of values. @@ -250,17 +243,12 @@ class Right(Sum): Example: >>> right = Right([tys.Bool, tys.Bool, tys.Bool], [TRUE, FALSE]) - >>> right - Right(left_typ=[Bool, Bool, Bool], vals=[TRUE, FALSE]) >>> str(right) 'Right(TRUE, FALSE)' >>> str(right.type_()) 'Either((Bool, Bool, Bool), (Bool, Bool))' """ - #: The values of this tuple. - vals: list[Value] - def __init__(self, left_typ: Iterable[tys.Type], vals: Iterable[Value]): val_list = list(vals) super().__init__( diff --git a/hugr-py/tests/test_tys.py b/hugr-py/tests/test_tys.py index 689234426..8fd98e6e1 100644 --- a/hugr-py/tests/test_tys.py +++ b/hugr-py/tests/test_tys.py @@ -1,6 +1,8 @@ from __future__ import annotations -from hugr.tys import Bool, Qubit, Sum, Tuple, UnitSum +import pytest + +from hugr.tys import Bool, Either, Option, Qubit, Sum, Tuple, Type, UnitSum def test_sums(): @@ -8,7 +10,37 @@ def test_sums(): assert Tuple(Bool, Qubit) == Sum([[Bool, Qubit]]) assert Sum([[Bool, Qubit]]).as_tuple() == Sum([[Bool, Qubit]]) + assert Sum([[Bool, Qubit], []]) == Option(Bool, Qubit) + assert Sum([[Bool, Qubit], []]) == Either([Bool, Qubit], []) + assert Option(Bool, Qubit) == Either([Bool, Qubit], []) + assert Sum([[Qubit], [Bool]]) == Either([Qubit], [Bool]) + assert Tuple() == Sum([[]]) assert UnitSum(0) == Sum([]) assert UnitSum(1) == Tuple() assert UnitSum(4) == Sum([[], [], [], []]) + + +@pytest.mark.parametrize( + ("ty", "string", "repr_str"), + [ + ( + Sum([[Bool], [Qubit], [Qubit, Bool]]), + "Sum([[Bool], [Qubit], [Qubit, Bool]])", + "Sum([[Bool], [Qubit], [Qubit, Bool]])", + ), + (UnitSum(1), "Unit", "Unit"), + (UnitSum(2), "Bool", "Bool"), + (UnitSum(3), "UnitSum(3)", "UnitSum(3)"), + (Tuple(Bool, Qubit), "Tuple(Bool, Qubit)", "Tuple(Bool, Qubit)"), + (Option(Bool, Qubit), "Option(Bool, Qubit)", "Option(Bool, Qubit)"), + ( + Either([Bool, Qubit], [Bool]), + "Either((Bool, Qubit), Bool)", + "Either(left=[Bool, Qubit], right=[Bool])", + ), + ], +) +def test_tys_sum_str(ty: Type, string: str, repr_str: str): + assert str(ty) == string + assert repr(ty) == repr_str diff --git a/hugr-py/tests/test_val.py b/hugr-py/tests/test_val.py new file mode 100644 index 000000000..364b9e3df --- /dev/null +++ b/hugr-py/tests/test_val.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import pytest + +from hugr import tys +from hugr.val import FALSE, TRUE, Left, None_, Right, Some, Sum, Tuple, UnitSum, Value + + +def test_sums(): + assert Sum(0, tys.Tuple(), []) == Tuple() + assert Sum(0, tys.Tuple(tys.Bool, tys.Bool), [TRUE, FALSE]) == Tuple(TRUE, FALSE) + + ty = tys.Sum([[tys.Bool, tys.Bool], []]) + assert Sum(0, ty, [TRUE, FALSE]) == Some(TRUE, FALSE) + assert Sum(0, ty, [TRUE, FALSE]) == Left([TRUE, FALSE], []) + assert Sum(1, ty, []) == None_(tys.Bool, tys.Bool) + assert Sum(1, ty, []) == Right([tys.Bool, tys.Bool], []) + + ty = tys.Sum([[tys.Bool], [tys.Bool]]) + assert Sum(0, ty, [TRUE]) == Left([TRUE], [tys.Bool]) + assert Sum(1, ty, [FALSE]) == Right([tys.Bool], [FALSE]) + + assert Tuple() == Sum(0, tys.Tuple(), []) + assert UnitSum(0, size=1) == Tuple() + assert UnitSum(2, size=4) == Sum(2, tys.UnitSum(size=4), []) + + +@pytest.mark.parametrize( + ("value", "string", "repr_str"), + [ + ( + Sum(0, tys.Sum([[tys.Bool], [tys.Qubit]]), [TRUE, FALSE]), + "Sum(tag=0, typ=Sum([[Bool], [Qubit]]), vals=[TRUE, FALSE])", + "Sum(tag=0, typ=Sum([[Bool], [Qubit]]), vals=[TRUE, FALSE])", + ), + (UnitSum(0, size=1), "Unit", "Unit"), + (UnitSum(0, size=2), "FALSE", "FALSE"), + (UnitSum(1, size=2), "TRUE", "TRUE"), + (UnitSum(2, size=5), "UnitSum(2, 5)", "UnitSum(2, 5)"), + (Tuple(TRUE, FALSE), "Tuple(TRUE, FALSE)", "Tuple(TRUE, FALSE)"), + (Some(TRUE, FALSE), "Some(TRUE, FALSE)", "Some(TRUE, FALSE)"), + (None_(tys.Bool, tys.Bool), "None", "None(Bool, Bool)"), + ( + Left([TRUE, FALSE], [tys.Bool]), + "Left(TRUE, FALSE)", + "Left(vals=[TRUE, FALSE], right_typ=[Bool])", + ), + ( + Right([tys.Bool, tys.Bool], [FALSE]), + "Right(FALSE)", + "Right(left_typ=[Bool, Bool], vals=[FALSE])", + ), + ], +) +def test_val_sum_str(value: Value, string: str, repr_str: str): + assert str(value) == string + assert repr(value) == repr_str