Skip to content

Commit b8df28e

Browse files
aborgna-qCalMacCQacl-cqc
authored
feat: Result type for ComposablePasses (#2703)
Includes an idea for simplifying the protocol's `_apply`/`_apply_inline` from #2697 by providing a helper function instead (859c811). --------- Co-authored-by: Callum Macpherson <[email protected]> Co-authored-by: Callum Macpherson <[email protected]> Co-authored-by: Alan Lawrence <[email protected]>
1 parent dbf8c8e commit b8df28e

File tree

2 files changed

+222
-72
lines changed

2 files changed

+222
-72
lines changed

hugr-py/src/hugr/passes/_composable_pass.py

Lines changed: 130 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -6,81 +6,88 @@
66
from __future__ import annotations
77

88
from copy import deepcopy
9-
from dataclasses import dataclass
10-
from typing import TYPE_CHECKING, Protocol, runtime_checkable
9+
from dataclasses import dataclass, field
10+
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
1111

1212
if TYPE_CHECKING:
1313
from collections.abc import Callable
1414

1515
from hugr.hugr.base import Hugr
1616

1717

18+
# Type alias for a pass name
19+
PassName = str
20+
21+
1822
@runtime_checkable
1923
class ComposablePass(Protocol):
2024
"""A Protocol which represents a composable Hugr transformation."""
2125

2226
def __call__(self, hugr: Hugr, *, inplace: bool = True) -> Hugr:
23-
"""Call the pass to transform a HUGR.
27+
"""Call the pass to transform a HUGR, returning a Hugr."""
28+
return self.run(hugr, inplace=inplace).hugr
29+
30+
def run(self, hugr: Hugr, *, inplace: bool = True) -> PassResult:
31+
"""Run the pass to transform a HUGR, returning a PassResult.
2432
25-
See :func:`_impl_pass_call` for a helper function to implement this method.
33+
See :func:`implement_pass_run` for a helper function to implement this method.
2634
"""
2735

2836
@property
29-
def name(self) -> str:
37+
def name(self) -> PassName:
3038
"""Returns the name of the pass."""
3139
return self.__class__.__name__
3240

3341
def then(self, other: ComposablePass) -> ComposablePass:
3442
"""Perform another composable pass after this pass."""
35-
# Provide a default implementation for composing passes.
36-
pass_list = []
37-
if isinstance(self, ComposedPass):
38-
pass_list.extend(self.passes)
39-
else:
40-
pass_list.append(self)
43+
return ComposedPass(self, other)
4144

42-
if isinstance(other, ComposedPass):
43-
pass_list.extend(other.passes)
44-
else:
45-
pass_list.append(other)
4645

47-
return ComposedPass(pass_list)
48-
49-
50-
def impl_pass_call(
46+
def implement_pass_run(
47+
composable_pass: ComposablePass,
5148
*,
5249
hugr: Hugr,
5350
inplace: bool,
54-
inplace_call: Callable[[Hugr], None] | None = None,
55-
copy_call: Callable[[Hugr], Hugr] | None = None,
56-
) -> Hugr:
57-
"""Helper function to implement a ComposablePass.__call__ method, given an
58-
inplace or copy-returning pass methods.
51+
inplace_call: Callable[[Hugr], PassResult] | None = None,
52+
copy_call: Callable[[Hugr], PassResult] | None = None,
53+
) -> PassResult:
54+
"""Helper function to implement a ComposablePass.run method, given an
55+
inplace or copy-returning pass method.
5956
6057
At least one of the `inplace_call` or `copy_call` arguments must be provided.
6158
59+
:param composable_pass: The pass being run. Used for error messages.
6260
:param hugr: The Hugr to apply the pass to.
6361
:param inplace: Whether to apply the pass inplace.
6462
:param inplace_call: The method to apply the pass inplace.
6563
:param copy_call: The method to apply the pass by copying the Hugr.
66-
:return: The transformed Hugr.
64+
:return: The result of the pass application.
65+
:raises ValueError: If neither `inplace_call` nor `copy_call` is provided.
6766
"""
68-
if inplace and inplace_call is not None:
69-
inplace_call(hugr)
70-
return hugr
71-
elif inplace and copy_call is not None:
72-
new_hugr = copy_call(hugr)
73-
hugr._overwrite_hugr(new_hugr)
74-
return hugr
75-
elif not inplace and copy_call is not None:
76-
return copy_call(hugr)
77-
elif not inplace and inplace_call is not None:
78-
new_hugr = deepcopy(hugr)
79-
inplace_call(new_hugr)
80-
return new_hugr
81-
else:
82-
msg = "Pass must implement at least an inplace or copy run method"
83-
raise ValueError(msg)
67+
if inplace:
68+
if inplace_call is not None:
69+
return inplace_call(hugr)
70+
elif copy_call is not None:
71+
pass_result = copy_call(hugr)
72+
pass_result.hugr = hugr
73+
if pass_result.modified:
74+
hugr._overwrite_hugr(pass_result.hugr)
75+
pass_result.inplace = True
76+
return pass_result
77+
elif not inplace:
78+
if copy_call is not None:
79+
return copy_call(hugr)
80+
elif inplace_call is not None:
81+
new_hugr = deepcopy(hugr)
82+
pass_result = inplace_call(new_hugr)
83+
pass_result.inplace = False
84+
return pass_result
85+
86+
msg = (
87+
f"{composable_pass.name} needs to implement at least "
88+
+ "an inplace or copy run method"
89+
)
90+
raise ValueError(msg)
8491

8592

8693
@dataclass
@@ -89,24 +96,92 @@ class ComposedPass(ComposablePass):
8996

9097
passes: list[ComposablePass]
9198

92-
def __call__(self, hugr: Hugr, *, inplace: bool = True) -> Hugr:
93-
def apply(hugr: Hugr) -> Hugr:
94-
result_hugr = hugr
99+
def __init__(self, *passes: ComposablePass) -> None:
100+
self.passes = []
101+
for composable_pass in passes:
102+
if isinstance(composable_pass, ComposedPass):
103+
self.passes.extend(composable_pass.passes)
104+
else:
105+
self.passes.append(composable_pass)
106+
107+
def run(self, hugr: Hugr, *, inplace: bool = True) -> PassResult:
108+
def apply(inplace: bool, hugr: Hugr) -> PassResult:
109+
pass_result = PassResult(hugr=hugr, inplace=inplace)
95110
for comp_pass in self.passes:
96-
result_hugr = comp_pass(result_hugr, inplace=False)
97-
return result_hugr
111+
new_result = comp_pass.run(pass_result.hugr, inplace=inplace)
112+
pass_result = pass_result.then(new_result)
113+
return pass_result
98114

99-
def apply_inplace(hugr: Hugr) -> None:
100-
for comp_pass in self.passes:
101-
comp_pass(hugr, inplace=True)
102-
103-
return impl_pass_call(
115+
return implement_pass_run(
116+
self,
104117
hugr=hugr,
105118
inplace=inplace,
106-
inplace_call=apply_inplace,
107-
copy_call=apply,
119+
inplace_call=lambda hugr: apply(True, hugr),
120+
copy_call=lambda hugr: apply(False, hugr),
108121
)
109122

110123
@property
111-
def name(self) -> str:
112-
return f"Composed({ ', '.join(pass_.name for pass_ in self.passes) })"
124+
def name(self) -> PassName:
125+
names = [composable_pass.name for composable_pass in self.passes]
126+
return f"Composed({ ', '.join(names) })"
127+
128+
129+
@dataclass
130+
class PassResult:
131+
"""The result of a series of composed passes applied to a HUGR.
132+
133+
Includes a flag indicating whether the passes modified the HUGR, and an
134+
arbitrary result object for each pass.
135+
136+
:attr hugr: The transformed Hugr.
137+
:attr inplace: Whether the pass was applied inplace.
138+
If this is `True`, `hugr` will be the same object passed as input.
139+
If this is `False`, `hugr` will be an independent copy of the original Hugr.
140+
:attr modified: Whether the pass made changes to the HUGR.
141+
If `False`, `hugr` will have the same contents as the original Hugr.
142+
If `True`, no guarantees are made about the contents of `hugr`.
143+
:attr results: The result of each applied pass, as a tuple of the pass name
144+
and the result.
145+
"""
146+
147+
hugr: Hugr
148+
inplace: bool = False
149+
modified: bool = False
150+
results: list[tuple[PassName, Any]] = field(default_factory=list)
151+
152+
@classmethod
153+
def for_pass(
154+
cls,
155+
composable_pass: ComposablePass,
156+
hugr: Hugr,
157+
*,
158+
result: Any,
159+
inplace: bool,
160+
modified: bool = True,
161+
) -> PassResult:
162+
"""Create a new PassResult after a pass application.
163+
164+
:param hugr: The Hugr that was transformed.
165+
:param composable_pass: The pass that was applied.
166+
:param result: The result of the pass application.
167+
:param inplace: Whether the pass was applied inplace.
168+
:param modified: Whether the pass modified the HUGR.
169+
"""
170+
return cls(
171+
hugr=hugr,
172+
inplace=inplace,
173+
modified=modified,
174+
results=[(composable_pass.name, result)],
175+
)
176+
177+
def then(self, other: PassResult) -> PassResult:
178+
"""Extend the PassResult with the results of another PassResult.
179+
180+
Keeps the hugr returned by the last pass.
181+
"""
182+
return PassResult(
183+
hugr=other.hugr,
184+
inplace=self.inplace and other.inplace,
185+
modified=self.modified or other.modified,
186+
results=self.results + other.results,
187+
)

hugr-py/tests/test_passes.py

Lines changed: 92 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,105 @@
1+
from copy import deepcopy
2+
3+
import pytest
4+
15
from hugr.hugr.base import Hugr
2-
from hugr.passes._composable_pass import ComposablePass, ComposedPass, impl_pass_call
6+
from hugr.passes._composable_pass import (
7+
ComposablePass,
8+
ComposedPass,
9+
PassResult,
10+
implement_pass_run,
11+
)
312

413

514
def test_composable_pass() -> None:
6-
class MyDummyPass(ComposablePass):
7-
def __call__(self, hugr: Hugr, inplace: bool = True) -> Hugr:
8-
return impl_pass_call(
15+
class DummyInlinePass(ComposablePass):
16+
def run(self, hugr: Hugr, inplace: bool = True) -> PassResult:
17+
return implement_pass_run(
18+
self,
919
hugr=hugr,
1020
inplace=inplace,
11-
inplace_call=lambda hugr: None,
21+
inplace_call=lambda hugr: PassResult.for_pass(
22+
self,
23+
hugr,
24+
result=None,
25+
inplace=True,
26+
# Say that we modified the HUGR even though we didn't
27+
modified=True,
28+
),
1229
)
1330

14-
dummy = MyDummyPass()
31+
class DummyCopyPass(ComposablePass):
32+
def run(self, hugr: Hugr, inplace: bool = True) -> PassResult:
33+
return implement_pass_run(
34+
self,
35+
hugr=hugr,
36+
inplace=inplace,
37+
copy_call=lambda hugr: PassResult.for_pass(
38+
self,
39+
deepcopy(hugr),
40+
result=None,
41+
inplace=False,
42+
# Say that we modified the HUGR even though we didn't
43+
modified=True,
44+
),
45+
)
1546

16-
composed_dummies = dummy.then(dummy)
47+
dummy_inline = DummyInlinePass()
48+
dummy_copy = DummyCopyPass()
1749

18-
my_composed_pass = ComposedPass([dummy, dummy])
19-
assert my_composed_pass.passes == [dummy, dummy]
50+
composed_dummies = dummy_inline.then(dummy_copy)
51+
assert isinstance(composed_dummies, ComposedPass)
2052

21-
assert isinstance(composed_dummies, ComposablePass)
22-
assert composed_dummies == my_composed_pass
53+
assert dummy_inline.name == "DummyInlinePass"
54+
assert dummy_copy.name == "DummyCopyPass"
55+
assert composed_dummies.name == "Composed(DummyInlinePass, DummyCopyPass)"
56+
assert composed_dummies.then(dummy_inline).then(composed_dummies).name == (
57+
"Composed("
58+
+ "DummyInlinePass, DummyCopyPass, "
59+
+ "DummyInlinePass, "
60+
+ "DummyInlinePass, DummyCopyPass)"
61+
)
2362

24-
assert dummy.name == "MyDummyPass"
25-
assert composed_dummies.name == "Composed(MyDummyPass, MyDummyPass)"
63+
# Apply the passes
64+
hugr: Hugr = Hugr()
65+
new_hugr = composed_dummies(hugr, inplace=False)
66+
assert hugr == new_hugr
67+
assert new_hugr is not hugr
2668

27-
assert (
28-
composed_dummies.then(my_composed_pass).name
29-
== "Composed(MyDummyPass, MyDummyPass, MyDummyPass, MyDummyPass)"
30-
)
69+
# Verify the pass results
70+
hugr = Hugr()
71+
inplace_result = composed_dummies.run(hugr, inplace=True)
72+
assert inplace_result.modified
73+
assert inplace_result.inplace
74+
assert inplace_result.results == [
75+
("DummyInlinePass", None),
76+
("DummyCopyPass", None),
77+
]
78+
assert inplace_result.hugr is hugr
79+
80+
hugr = Hugr()
81+
copy_result = composed_dummies.run(hugr, inplace=False)
82+
assert copy_result.modified
83+
assert not copy_result.inplace
84+
assert copy_result.results == [
85+
("DummyInlinePass", None),
86+
("DummyCopyPass", None),
87+
]
88+
assert copy_result.hugr is not hugr
89+
90+
91+
def test_invalid_composable_pass() -> None:
92+
class DummyInvalidPass(ComposablePass):
93+
def run(self, hugr: Hugr, inplace: bool = True) -> PassResult:
94+
return implement_pass_run(
95+
self,
96+
hugr=hugr,
97+
inplace=inplace,
98+
)
99+
100+
dummy_invalid = DummyInvalidPass()
101+
with pytest.raises(
102+
ValueError,
103+
match="DummyInvalidPass needs to implement at least an inplace or copy run method", # noqa: E501
104+
):
105+
dummy_invalid.run(Hugr())

0 commit comments

Comments
 (0)