diff --git a/pennylane/compiler/python_compiler/dialects/__init__.py b/pennylane/compiler/python_compiler/dialects/__init__.py index a562a50e403..eb2e756b57e 100644 --- a/pennylane/compiler/python_compiler/dialects/__init__.py +++ b/pennylane/compiler/python_compiler/dialects/__init__.py @@ -14,9 +14,11 @@ """This submodule contains xDSL dialects for the Python compiler.""" +from .catalyst import Catalyst from .mbqc import MBQC from .quantum import Quantum -from .catalyst import Catalyst from .qec import QEC +from .transform import Transform + -__all__ = ["Catalyst", "MBQC", "Quantum", "QEC"] +__all__ = ["Catalyst", "MBQC", "Quantum", "QEC", "Transform"] diff --git a/pennylane/compiler/python_compiler/dialects/transform.py b/pennylane/compiler/python_compiler/dialects/transform.py new file mode 100644 index 00000000000..1653e557ba3 --- /dev/null +++ b/pennylane/compiler/python_compiler/dialects/transform.py @@ -0,0 +1,124 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This file contains an updated version of the transform dialect. +As of the time of writing, xDSL uses the MLIR released with LLVM's +version 20.1.7. However, https://github.com/PennyLaneAI/catalyst/pull/1916 +will be updating MLIR where the transform dialect has the +`apply_registered_pass` operation re-defined. + +See the following changelog on the above PR + + Things related to transform.apply_registered_pass op: + + It now takes in a dynamic_options + + [MLIR][Transform] Allow ApplyRegisteredPassOp to take options as + a param llvm/llvm-project#142683. We don't need to use this as all our pass options are static. + https://github.com/llvm/llvm-project/pull/142683 + + The options it takes in are now dictionaries instead of strings + [MLIR][Transform] apply_registered_pass op's options as a dict llvm/llvm-project#143159 + https://github.com/llvm/llvm-project/pull/143159 + +This file will re-define the apply_registered_pass operation in xDSL +and the transform dialect. + +Once xDSL moves to a newer version of MLIR, these changes should +be contributed upstream. +""" + +from xdsl.dialects.builtin import Dialect + +# pylint: disable=unused-wildcard-import,wildcard-import,undefined-variable,too-few-public-methods +from xdsl.dialects.transform import ApplyRegisteredPassOp as xApplyRegisteredPassOp +from xdsl.dialects.transform import ( + DictionaryAttr, + StringAttr, +) +from xdsl.dialects.transform import Transform as xTransform +from xdsl.dialects.transform import ( + TransformHandleType, + irdl_op_definition, + operand_def, + prop_def, + result_def, +) +from xdsl.ir import Attribute, SSAValue +from xdsl.irdl import IRDLOperation, ParsePropInAttrDict + + +@irdl_op_definition +# pylint: disable=function-redefined +class ApplyRegisteredPassOp(IRDLOperation): + """ + See external [documentation](https://mlir.llvm.org/docs/Dialects/Transform/#transformapply_registered_pass-transformapplyregisteredpassop). + """ + + name = "transform.apply_registered_pass" + + options = prop_def(DictionaryAttr, default_value=DictionaryAttr({})) + pass_name = prop_def(StringAttr) + target = operand_def(TransformHandleType) + result = result_def(TransformHandleType) + # While this assembly format doesn't match + # the one in upstream MLIR, + # this is because xDSL currently lacks CustomDirectives + # https://mlir.llvm.org/docs/DefiningDialects/Operations/#custom-directives + # https://github.com/xdslproject/xdsl/pull/4829 + # However, storing the property in the attribute should still work + # specially when parsing and printing in generic format. + # Which is how Catalyst and XDSL currently communicate at the moment. + # TODO: Add test. + assembly_format = "$pass_name `to` $target attr-dict `:` functional-type(operands, results)" + irdl_options = [ParsePropInAttrDict()] + + def __init__( + self, + pass_name: str | StringAttr, + target: SSAValue, + options: dict[str | StringAttr, Attribute | str | bool | int] | None = None, + ): + if isinstance(pass_name, str): + pass_name = StringAttr(pass_name) + + if isinstance(options, dict): + options = DictionaryAttr(options) + + super().__init__( + properties={ + "pass_name": pass_name, + "options": options, + }, + operands=[target], + result_types=[target.type], + ) + + +# Copied over from xDSL's sources +# the main difference will be the use +# of a different ApplyRegisteredPassOp +operations = list(xTransform.operations) +del operations[operations.index(xApplyRegisteredPassOp)] +operations.append(ApplyRegisteredPassOp) + +Transform = Dialect( + "transform", + [ + *operations, + ], + [ + *xTransform.attributes, + ], +) diff --git a/pennylane/compiler/python_compiler/jax_utils.py b/pennylane/compiler/python_compiler/jax_utils.py index 4781a12a073..3691ff3c92d 100644 --- a/pennylane/compiler/python_compiler/jax_utils.py +++ b/pennylane/compiler/python_compiler/jax_utils.py @@ -30,12 +30,11 @@ from xdsl.dialects import scf as xscf from xdsl.dialects import stablehlo as xstablehlo from xdsl.dialects import tensor as xtensor -from xdsl.dialects import transform as xtransform from xdsl.ir import Dialect as xDialect from xdsl.parser import Parser as xParser from xdsl.traits import SymbolTable as xSymbolTable -from .dialects import MBQC, QEC, Catalyst, Quantum +from .dialects import MBQC, QEC, Catalyst, Quantum, Transform JaxJittedFunction: TypeAlias = _jax.PjitFunction # pylint: disable=c-extension-no-member @@ -59,7 +58,7 @@ class QuantumParser(xParser): # pylint: disable=abstract-method,too-few-public- xscf.Scf, xstablehlo.StableHLO, xtensor.Tensor, - xtransform.Transform, + Transform, Quantum, MBQC, Catalyst, diff --git a/pennylane/compiler/python_compiler/transforms/api/transform_interpreter.py b/pennylane/compiler/python_compiler/transforms/api/transform_interpreter.py index 9d5541d5859..824796e548a 100644 --- a/pennylane/compiler/python_compiler/transforms/api/transform_interpreter.py +++ b/pennylane/compiler/python_compiler/transforms/api/transform_interpreter.py @@ -23,7 +23,8 @@ from catalyst.compiler import _quantum_opt # pylint: disable=protected-access from xdsl.context import Context -from xdsl.dialects import builtin, transform +from xdsl.dialects import builtin +from xdsl.dialects.transform import NamedSequenceOp from xdsl.interpreter import Interpreter, PythonValues, impl, register_impls from xdsl.interpreters.transform import TransformFunctions from xdsl.parser import Parser @@ -32,6 +33,8 @@ from xdsl.rewriter import Rewriter from xdsl.utils.exceptions import PassFailedException +from ...dialects.transform import ApplyRegisteredPassOp + # pylint: disable=too-few-public-methods @register_impls @@ -43,11 +46,11 @@ class TransformFunctionsExt(TransformFunctions): then it will try to run this pass in Catalyst. """ - @impl(transform.ApplyRegisteredPassOp) + @impl(ApplyRegisteredPassOp) def run_apply_registered_pass_op( # pragma: no cover self, _interpreter: Interpreter, - op: transform.ApplyRegisteredPassOp, + op: ApplyRegisteredPassOp, args: PythonValues, ) -> PythonValues: """Try to run the pass in xDSL, if it can't run on catalyst""" @@ -56,7 +59,7 @@ def run_apply_registered_pass_op( # pragma: no cover if pass_name in self.passes: # pragma: no cover pass_class = self.passes[pass_name]() - pipeline = PassPipeline((pass_class(),)) + pipeline = PassPipeline((pass_class(**op.options.data),)) pipeline.apply(self.ctx, args[0]) return (args[0],) @@ -86,12 +89,10 @@ def __init__(self, passes): self.passes = passes @staticmethod - def find_transform_entry_point( - root: builtin.ModuleOp, entry_point: str - ) -> transform.NamedSequenceOp: + def find_transform_entry_point(root: builtin.ModuleOp, entry_point: str) -> NamedSequenceOp: """Find the entry point of the program""" for op in root.walk(): - if isinstance(op, transform.NamedSequenceOp) and op.sym_name.data == entry_point: + if isinstance(op, NamedSequenceOp) and op.sym_name.data == entry_point: return op raise PassFailedException( # pragma: no cover f"{root} could not find a nested named sequence with name: {entry_point}" diff --git a/tests/python_compiler/dialects/test_transform_dialect.py b/tests/python_compiler/dialects/test_transform_dialect.py new file mode 100644 index 00000000000..2ffcfb828e1 --- /dev/null +++ b/tests/python_compiler/dialects/test_transform_dialect.py @@ -0,0 +1,149 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test module for pennylane/compiler/python_compiler/transform.py.""" + +from dataclasses import dataclass + +import pytest + +# pylint: disable=wrong-import-position + +xdsl = pytest.importorskip("xdsl") +filecheck = pytest.importorskip("filecheck") + +pytestmark = pytest.mark.external + +from xdsl import passes +from xdsl.context import Context +from xdsl.dialects import builtin +from xdsl.dialects.builtin import DictionaryAttr, IntegerAttr, i64 +from xdsl.dialects.transform import AnyOpType +from xdsl.utils.exceptions import VerifyException +from xdsl.utils.test_value import create_ssa_value + +from pennylane.compiler.python_compiler.dialects import transform +from pennylane.compiler.python_compiler.dialects.transform import ApplyRegisteredPassOp +from pennylane.compiler.python_compiler.jax_utils import xdsl_from_docstring +from pennylane.compiler.python_compiler.transforms.api import ( + ApplyTransformSequence, + compiler_transform, +) + + +def test_dict_options(): + """Test ApplyRegisteredPassOp constructor with dict options.""" + target = create_ssa_value(AnyOpType()) + options = {"option1": 1, "option2": True} + + op = ApplyRegisteredPassOp("canonicalize", target, options) + + assert op.pass_name.data == "canonicalize" + assert isinstance(op.options, DictionaryAttr) + assert op.options == DictionaryAttr({"option1": 1, "option2": True}) + assert op.verify_() is None + + +def test_attr_options(): + """Test ApplyRegisteredPassOp constructor with DictionaryAttr options.""" + target = create_ssa_value(AnyOpType()) + options = DictionaryAttr({"test-option": IntegerAttr(42, i64)}) + + # This should trigger the __init__ method + op = ApplyRegisteredPassOp("canonicalize", target, options) + + assert op.pass_name.data == "canonicalize" + assert isinstance(op.options, DictionaryAttr) + assert op.options == DictionaryAttr({"test-option": IntegerAttr(42, i64)}) + assert op.verify_() is None + + +def test_none_options(): + """Test ApplyRegisteredPassOp constructor with None options.""" + target = create_ssa_value(AnyOpType()) + + # This should trigger the __init__ method + op = ApplyRegisteredPassOp("canonicalize", target, None) + + assert op.pass_name.data == "canonicalize" + assert isinstance(op.options, DictionaryAttr) + assert op.options == DictionaryAttr({}) + assert op.verify_() is None + + +def test_invalid_options(): + """Test ApplyRegisteredPassOp constructor with invalid options type.""" + target = create_ssa_value(AnyOpType()) + + with pytest.raises( + VerifyException, match="invalid_options should be of base attribute dictionary" + ): + ApplyRegisteredPassOp("canonicalize", target, "invalid_options").verify_() + + +def test_transform_dialect_filecheck(run_filecheck): + """Test that the transform dialect operations are parsed correctly.""" + program = """ + "builtin.module"() ({ + "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({ + ^bb0(%arg0: !transform.any_op): + %0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op + // CHECK: options = {"invalid-option" = 1 : i64} + %1 = "transform.apply_registered_pass"(%0) <{options = {"invalid-option" = 1 : i64}, pass_name = "canonicalize"}> : (!transform.any_op) -> !transform.any_op + "transform.yield"() : () -> () + }) : () -> () + }) {transform.with_named_sequence} : () -> () + """ + + run_filecheck(program) + + +def test_integration_for_transform_interpreter(capsys): + """Test that a pass with options is run via the transform interpreter""" + + @compiler_transform + @dataclass(frozen=True) + class _HelloWorld(passes.ModulePass): + name = "hello-world" + + custom_print: str | None = None + + def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None: + if self.custom_print: + print(self.custom_print) + else: + print("hello world") + + @xdsl_from_docstring + def program(): + """ + builtin.module { + builtin.module { + transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { + %0 = "transform.apply_registered_pass"(%arg0) <{options = {"custom_print" = "Hello from custom option!"}, pass_name = "hello-world"}> : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + transform.yield + } + } + } + """ + + ctx = xdsl.context.Context() + ctx.load_dialect(builtin.Builtin) + ctx.load_dialect(transform.Transform) + + mod = program() + pipeline = xdsl.passes.PassPipeline((ApplyTransformSequence(),)) + pipeline.apply(ctx, mod) + + assert "Hello from custom option!" in capsys.readouterr().out diff --git a/tests/python_compiler/test_python_compiler.py b/tests/python_compiler/test_python_compiler.py index 58e5a435096..f223da52e66 100644 --- a/tests/python_compiler/test_python_compiler.py +++ b/tests/python_compiler/test_python_compiler.py @@ -29,10 +29,11 @@ from catalyst import CompileError from xdsl import passes from xdsl.context import Context -from xdsl.dialects import builtin, transform +from xdsl.dialects import builtin from xdsl.interpreters import Interpreter from pennylane.compiler.python_compiler import Compiler +from pennylane.compiler.python_compiler.dialects import transform from pennylane.compiler.python_compiler.jax_utils import ( jax_from_docstring, module,