Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c602a54
Use the transform dialect from our repository instead of the xDSL one
erick-xanadu Jul 24, 2025
eb94a67
Redefine Transform
erick-xanadu Jul 24, 2025
ebade75
copies over ApplyRegisteredPassOp
erick-xanadu Jul 24, 2025
966a678
Use DictionaryAttr instead of StringAttr
erick-xanadu Jul 24, 2025
c3905c6
Redefine ApplyRegisteredPassOp
erick-xanadu Jul 24, 2025
2e21cb8
initial test
erick-xanadu Jul 24, 2025
0d1f4fa
Apply suggestions from code review
erick-xanadu Jul 24, 2025
9ecadb7
Apply suggestions from code review
erick-xanadu Jul 24, 2025
8c62649
Fix imports
mehrdad2m Jul 31, 2025
6b83c17
move test to separate file
mehrdad2m Jul 31, 2025
25dd616
fix type hints
mehrdad2m Aug 1, 2025
3397df1
Merge branch 'master' into eochoa/2025-07-24/transform-dialect-update
mehrdad2m Aug 1, 2025
701f80b
isort after rebase
mehrdad2m Aug 1, 2025
c0fda72
fix type hint
mehrdad2m Aug 1, 2025
6f3e1ad
Merge branch 'master' into eochoa/2025-07-24/transform-dialect-update
mehrdad2m Aug 1, 2025
0b7b631
import original transform separately
mehrdad2m Aug 1, 2025
58a4a0f
remove redundant code
mehrdad2m Aug 1, 2025
9f51c7b
check full option dict
mehrdad2m Aug 1, 2025
b0a3120
Update transform_interpreter to utilize full options data for PassPip…
mehrdad2m Aug 5, 2025
262b32b
Add more unit tests for ApplyRegisteredPassOp with various options
mehrdad2m Aug 5, 2025
28c98e8
Merge branch 'master' into eochoa/2025-07-24/transform-dialect-update
mehrdad2m Aug 7, 2025
8558383
Merge branch 'master' into eochoa/2025-07-24/transform-dialect-update
mehrdad2m Aug 14, 2025
6e94592
remove wildcard import
mehrdad2m Aug 14, 2025
d3b9139
fix a mistake
mehrdad2m Aug 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pennylane/compiler/python_compiler/dialects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

"""This submodule contains xDSL dialects for the Python compiler."""

from .catalyst import Catalyst
from .mbqc import MBQC
from .quantum import Quantum
from .catalyst import Catalyst
from .transform import Transform


__all__ = ["Catalyst", "MBQC", "Quantum"]
__all__ = ["Catalyst", "MBQC", "Quantum", "Transform"]
103 changes: 103 additions & 0 deletions pennylane/compiler/python_compiler/dialects/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains an updated version of the transform dialect.
As of the time of writing, xDSL uses the MLIR released with LLVM's
version 20.1.7. However, https://github.com/PennyLaneAI/catalyst/pull/1916
will be updating MLIR where the transform dialect has the
`apply_registered_pass` operation re-defined.

See the following changelog on the above PR

Things related to transform.apply_registered_pass op:

It now takes in a dynamic_options

[MLIR][Transform] Allow ApplyRegisteredPassOp to take options as
a param llvm/llvm-project#142683. We don't need to use this as all our pass options are static.
https://github.com/llvm/llvm-project/pull/142683

The options it takes in are now dictionaries instead of strings
[MLIR][Transform] apply_registered_pass op's options as a dict llvm/llvm-project#143159
https://github.com/llvm/llvm-project/pull/143159

This file will re-define the apply_registered_pass operation in xDSL
and the transform dialect.

Once xDSL moves to a newer version of MLIR, these changes should
be contributed upstream.
"""

# pylint: disable=unused-wildcard-import,wildcard-import,undefined-variable,too-few-public-methods
from xdsl.dialects.transform import ApplyRegisteredPassOp as xApplyRegisteredPassOp
from xdsl.dialects.transform import Transform as xTransform
from xdsl.dialects.transform import *


@irdl_op_definition
# pylint: disable=function-redefined
class ApplyRegisteredPassOp(IRDLOperation):
"""
See external [documentation](https://mlir.llvm.org/docs/Dialects/Transform/#transformapply_registered_pass-transformapplyregisteredpassop).
"""

name = "transform.apply_registered_pass"

options = prop_def(DictionaryAttr, default_value=DictionaryAttr({}))
pass_name = prop_def(StringAttr)
target = operand_def(TransformHandleType)
result = result_def(TransformHandleType)
# While this assembly format doesn't match
# the one in upstream MLIR,
# this is because xDSL currently lacks CustomDirectives
# https://mlir.llvm.org/docs/DefiningDialects/Operations/#custom-directives
# https://github.com/xdslproject/xdsl/pull/4829
# However, storing the property in the attribute should still work
# specially when parsing and printing in generic format.
# Which is how Catalyst and XDSL currently communicate at the moment.
# TODO: Add test.
assembly_format = "$pass_name `to` $target attr-dict `:` functional-type(operands, results)"
irdl_options = [ParsePropInAttrDict()]

def __init__(
self,
pass_name: str | StringAttr,
target: SSAValue,
options: dict[str | StringAttr, Attribute | str | bool | int] | None = None,
):
if isinstance(pass_name, str):
pass_name = StringAttr(pass_name)

if isinstance(options, dict):
options = DictionaryAttr(options)

super().__init__(
properties={
"pass_name": pass_name,
"options": options,
},
operands=[target],
result_types=[target.type],
)


# Copied over from xDSL's sources
# the main difference will be the use
# of a different ApplyRegisteredPassOp
operations = list(xTransform.operations)
del operations[operations.index(xApplyRegisteredPassOp)]
operations.append(ApplyRegisteredPassOp)
attributes = list(xTransform.attributes)

Transform = Dialect("transform", operations, attributes)
5 changes: 2 additions & 3 deletions pennylane/compiler/python_compiler/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,11 @@
from xdsl.dialects import scf as xscf
from xdsl.dialects import stablehlo as xstablehlo
from xdsl.dialects import tensor as xtensor
from xdsl.dialects import transform as xtransform
from xdsl.ir import Dialect as xDialect
from xdsl.parser import Parser as xParser
from xdsl.traits import SymbolTable as xSymbolTable

from .dialects import MBQC, Catalyst, Quantum
from .dialects import MBQC, Catalyst, Quantum, Transform

JaxJittedFunction: TypeAlias = _jax.PjitFunction # pylint: disable=c-extension-no-member

Expand All @@ -59,7 +58,7 @@ class QuantumParser(xParser): # pylint: disable=abstract-method,too-few-public-
xscf.Scf,
xstablehlo.StableHLO,
xtensor.Tensor,
xtransform.Transform,
Transform,
Quantum,
MBQC,
Catalyst,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from catalyst.compiler import _quantum_opt # pylint: disable=protected-access
from xdsl.context import Context
from xdsl.dialects import builtin, transform
from xdsl.dialects import builtin
from xdsl.interpreter import Interpreter, PythonValues, impl, register_impls
from xdsl.interpreters.transform import TransformFunctions
from xdsl.parser import Parser
Expand All @@ -32,6 +32,8 @@
from xdsl.rewriter import Rewriter
from xdsl.utils.exceptions import PassFailedException

from ...dialects import transform


# pylint: disable=too-few-public-methods
@register_impls
Expand Down
42 changes: 42 additions & 0 deletions tests/python_compiler/dialects/test_transform_dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit test module for pennylane/compiler/python_compiler/transform.py."""

import pytest

# pylint: disable=wrong-import-position

xdsl = pytest.importorskip("xdsl")
filecheck = pytest.importorskip("filecheck")

pytestmark = pytest.mark.external


def test_transform_dialect_update(run_filecheck):
"""Test that the transform dialect is updated correctly."""

program = """
"builtin.module"() ({
"transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
^bb0(%arg0: !transform.any_op):
%0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
// CHECK: "invalid-option"
%1 = "transform.apply_registered_pass"(%0) <{options = {"invalid-option" = 1 : i64}, pass_name = "canonicalize"}> : (!transform.any_op) -> !transform.any_op
"transform.yield"() : () -> ()
}) : () -> ()
}) {transform.with_named_sequence} : () -> ()
"""

run_filecheck(program, ())
3 changes: 2 additions & 1 deletion tests/python_compiler/test_python_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@
from catalyst import CompileError
from xdsl import passes
from xdsl.context import Context
from xdsl.dialects import builtin, transform
from xdsl.dialects import builtin
from xdsl.interpreters import Interpreter

from pennylane.compiler.python_compiler import Compiler
from pennylane.compiler.python_compiler.dialects import transform
from pennylane.compiler.python_compiler.jax_utils import (
jax_from_docstring,
module,
Expand Down
Loading