diff --git a/hugr-py/src/hugr/passes/__init__.py b/hugr-py/src/hugr/passes/__init__.py new file mode 100644 index 000000000..8a34628c1 --- /dev/null +++ b/hugr-py/src/hugr/passes/__init__.py @@ -0,0 +1 @@ +"""A hugr-py passes module for hugr transformations.""" diff --git a/hugr-py/src/hugr/passes/composable_pass.py b/hugr-py/src/hugr/passes/composable_pass.py new file mode 100644 index 000000000..90f41fb0e --- /dev/null +++ b/hugr-py/src/hugr/passes/composable_pass.py @@ -0,0 +1,51 @@ +"""A Protocol for a composable pass.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from hugr.hugr.base import Hugr + + +@runtime_checkable +class ComposablePass(Protocol): + """A Protocol which represents a composable Hugr transformation.""" + + def __call__(self, hugr: Hugr) -> None: + """Call the pass to transform a HUGR.""" + ... + + @property + def name(self) -> str: + """Returns the name of the pass.""" + return self.__class__.__name__ + + def then(self, other: ComposablePass) -> ComposablePass: + """Perform another composable pass after this pass.""" + # Provide a default implementation for composing passes. + pass_list = [] + if isinstance(self, ComposedPass): + pass_list.extend(self.passes) + else: + pass_list.append(self) + + if isinstance(other, ComposedPass): + pass_list.extend(other.passes) + else: + pass_list.append(other) + + return ComposedPass(pass_list) + + +@dataclass +class ComposedPass(ComposablePass): + """A sequence of composable passes.""" + + passes: list[ComposablePass] + + def __call__(self, hugr: Hugr): + """Call all of the passes in sequence.""" + for comp_pass in self.passes: + comp_pass(hugr) diff --git a/hugr-py/tests/test_passes.py b/hugr-py/tests/test_passes.py new file mode 100644 index 000000000..542ef1f8b --- /dev/null +++ b/hugr-py/tests/test_passes.py @@ -0,0 +1,26 @@ +from hugr.hugr.base import Hugr +from hugr.passes.composable_pass import ComposablePass, ComposedPass + + +def test_composable_pass() -> None: + class MyDummyPass(ComposablePass): + def __call__(self, hugr: Hugr) -> None: + return self(hugr) + + def then(self, other: ComposablePass) -> ComposablePass: + return ComposedPass([self, other]) + + @property + def name(self) -> str: + return "Dummy" + + dummy = MyDummyPass() + + composed = dummy.then(dummy) + + my_composed_pass = ComposedPass([dummy, dummy]) + + assert my_composed_pass.passes == [dummy, dummy] + assert isinstance(my_composed_pass, ComposablePass) + assert isinstance(composed, ComposablePass) + assert dummy.name == "Dummy"