66from __future__ import annotations
77
88from copy import deepcopy
9- from dataclasses import dataclass
10- from typing import TYPE_CHECKING , Protocol , runtime_checkable
9+ from dataclasses import dataclass , field
10+ from typing import TYPE_CHECKING , Any , Protocol , runtime_checkable
1111
1212if TYPE_CHECKING :
1313 from collections .abc import Callable
1414
1515 from hugr .hugr .base import Hugr
1616
1717
18+ # Type alias for a pass name
19+ PassName = str
20+
21+
1822@runtime_checkable
1923class ComposablePass (Protocol ):
2024 """A Protocol which represents a composable Hugr transformation."""
2125
2226 def __call__ (self , hugr : Hugr , * , inplace : bool = True ) -> Hugr :
23- """Call the pass to transform a HUGR.
27+ """Call the pass to transform a HUGR, returning a Hugr."""
28+ return self .run (hugr , inplace = inplace ).hugr
29+
30+ def run (self , hugr : Hugr , * , inplace : bool = True ) -> PassResult :
31+ """Run the pass to transform a HUGR, returning a PassResult.
2432
25- See :func:`_impl_pass_call ` for a helper function to implement this method.
33+ See :func:`implement_pass_run ` for a helper function to implement this method.
2634 """
2735
2836 @property
29- def name (self ) -> str :
37+ def name (self ) -> PassName :
3038 """Returns the name of the pass."""
3139 return self .__class__ .__name__
3240
3341 def then (self , other : ComposablePass ) -> ComposablePass :
3442 """Perform another composable pass after this pass."""
35- # Provide a default implementation for composing passes.
36- pass_list = []
37- if isinstance (self , ComposedPass ):
38- pass_list .extend (self .passes )
39- else :
40- pass_list .append (self )
43+ return ComposedPass (self , other )
4144
42- if isinstance (other , ComposedPass ):
43- pass_list .extend (other .passes )
44- else :
45- pass_list .append (other )
4645
47- return ComposedPass (pass_list )
48-
49-
50- def impl_pass_call (
46+ def implement_pass_run (
47+ composable_pass : ComposablePass ,
5148 * ,
5249 hugr : Hugr ,
5350 inplace : bool ,
54- inplace_call : Callable [[Hugr ], None ] | None = None ,
55- copy_call : Callable [[Hugr ], Hugr ] | None = None ,
56- ) -> Hugr :
57- """Helper function to implement a ComposablePass.__call__ method, given an
58- inplace or copy-returning pass methods .
51+ inplace_call : Callable [[Hugr ], PassResult ] | None = None ,
52+ copy_call : Callable [[Hugr ], PassResult ] | None = None ,
53+ ) -> PassResult :
54+ """Helper function to implement a ComposablePass.run method, given an
55+ inplace or copy-returning pass method .
5956
6057 At least one of the `inplace_call` or `copy_call` arguments must be provided.
6158
59+ :param composable_pass: The pass being run. Used for error messages.
6260 :param hugr: The Hugr to apply the pass to.
6361 :param inplace: Whether to apply the pass inplace.
6462 :param inplace_call: The method to apply the pass inplace.
6563 :param copy_call: The method to apply the pass by copying the Hugr.
66- :return: The transformed Hugr.
64+ :return: The result of the pass application.
65+ :raises ValueError: If neither `inplace_call` nor `copy_call` is provided.
6766 """
68- if inplace and inplace_call is not None :
69- inplace_call (hugr )
70- return hugr
71- elif inplace and copy_call is not None :
72- new_hugr = copy_call (hugr )
73- hugr ._overwrite_hugr (new_hugr )
74- return hugr
75- elif not inplace and copy_call is not None :
76- return copy_call (hugr )
77- elif not inplace and inplace_call is not None :
78- new_hugr = deepcopy (hugr )
79- inplace_call (new_hugr )
80- return new_hugr
81- else :
82- msg = "Pass must implement at least an inplace or copy run method"
83- raise ValueError (msg )
67+ if inplace :
68+ if inplace_call is not None :
69+ return inplace_call (hugr )
70+ elif copy_call is not None :
71+ pass_result = copy_call (hugr )
72+ pass_result .hugr = hugr
73+ if pass_result .modified :
74+ hugr ._overwrite_hugr (pass_result .hugr )
75+ pass_result .inplace = True
76+ return pass_result
77+ elif not inplace :
78+ if copy_call is not None :
79+ return copy_call (hugr )
80+ elif inplace_call is not None :
81+ new_hugr = deepcopy (hugr )
82+ pass_result = inplace_call (new_hugr )
83+ pass_result .inplace = False
84+ return pass_result
85+
86+ msg = (
87+ f"{ composable_pass .name } needs to implement at least "
88+ + "an inplace or copy run method"
89+ )
90+ raise ValueError (msg )
8491
8592
8693@dataclass
@@ -89,24 +96,92 @@ class ComposedPass(ComposablePass):
8996
9097 passes : list [ComposablePass ]
9198
92- def __call__ (self , hugr : Hugr , * , inplace : bool = True ) -> Hugr :
93- def apply (hugr : Hugr ) -> Hugr :
94- result_hugr = hugr
99+ def __init__ (self , * passes : ComposablePass ) -> None :
100+ self .passes = []
101+ for composable_pass in passes :
102+ if isinstance (composable_pass , ComposedPass ):
103+ self .passes .extend (composable_pass .passes )
104+ else :
105+ self .passes .append (composable_pass )
106+
107+ def run (self , hugr : Hugr , * , inplace : bool = True ) -> PassResult :
108+ def apply (inplace : bool , hugr : Hugr ) -> PassResult :
109+ pass_result = PassResult (hugr = hugr , inplace = inplace )
95110 for comp_pass in self .passes :
96- result_hugr = comp_pass (result_hugr , inplace = False )
97- return result_hugr
111+ new_result = comp_pass .run (pass_result .hugr , inplace = inplace )
112+ pass_result = pass_result .then (new_result )
113+ return pass_result
98114
99- def apply_inplace (hugr : Hugr ) -> None :
100- for comp_pass in self .passes :
101- comp_pass (hugr , inplace = True )
102-
103- return impl_pass_call (
115+ return implement_pass_run (
116+ self ,
104117 hugr = hugr ,
105118 inplace = inplace ,
106- inplace_call = apply_inplace ,
107- copy_call = apply ,
119+ inplace_call = lambda hugr : apply ( True , hugr ) ,
120+ copy_call = lambda hugr : apply ( False , hugr ) ,
108121 )
109122
110123 @property
111- def name (self ) -> str :
112- return f"Composed({ ', ' .join (pass_ .name for pass_ in self .passes ) } )"
124+ def name (self ) -> PassName :
125+ names = [composable_pass .name for composable_pass in self .passes ]
126+ return f"Composed({ ', ' .join (names ) } )"
127+
128+
129+ @dataclass
130+ class PassResult :
131+ """The result of a series of composed passes applied to a HUGR.
132+
133+ Includes a flag indicating whether the passes modified the HUGR, and an
134+ arbitrary result object for each pass.
135+
136+ :attr hugr: The transformed Hugr.
137+ :attr inplace: Whether the pass was applied inplace.
138+ If this is `True`, `hugr` will be the same object passed as input.
139+ If this is `False`, `hugr` will be an independent copy of the original Hugr.
140+ :attr modified: Whether the pass made changes to the HUGR.
141+ If `False`, `hugr` will have the same contents as the original Hugr.
142+ If `True`, no guarantees are made about the contents of `hugr`.
143+ :attr results: The result of each applied pass, as a tuple of the pass name
144+ and the result.
145+ """
146+
147+ hugr : Hugr
148+ inplace : bool = False
149+ modified : bool = False
150+ results : list [tuple [PassName , Any ]] = field (default_factory = list )
151+
152+ @classmethod
153+ def for_pass (
154+ cls ,
155+ composable_pass : ComposablePass ,
156+ hugr : Hugr ,
157+ * ,
158+ result : Any ,
159+ inplace : bool ,
160+ modified : bool = True ,
161+ ) -> PassResult :
162+ """Create a new PassResult after a pass application.
163+
164+ :param hugr: The Hugr that was transformed.
165+ :param composable_pass: The pass that was applied.
166+ :param result: The result of the pass application.
167+ :param inplace: Whether the pass was applied inplace.
168+ :param modified: Whether the pass modified the HUGR.
169+ """
170+ return cls (
171+ hugr = hugr ,
172+ inplace = inplace ,
173+ modified = modified ,
174+ results = [(composable_pass .name , result )],
175+ )
176+
177+ def then (self , other : PassResult ) -> PassResult :
178+ """Extend the PassResult with the results of another PassResult.
179+
180+ Keeps the hugr returned by the last pass.
181+ """
182+ return PassResult (
183+ hugr = other .hugr ,
184+ inplace = self .inplace and other .inplace ,
185+ modified = self .modified or other .modified ,
186+ results = self .results + other .results ,
187+ )
0 commit comments