diff --git a/.dep-versions b/.dep-versions
index d755906e6d..849b989020 100644
--- a/.dep-versions
+++ b/.dep-versions
@@ -2,9 +2,9 @@
# To update JAX version alongside compatible dependency tags, run the following script:
# python3 .github/workflows/set_dep_versions.py {JAX_version}
jax=0.6.2
-mhlo=617a9361d186199480c080c9e8c474a5e30c22d1
-llvm=179d30f8c3fddd3c85056fd2b8e877a4a8513158
-enzyme=v0.0.180
+mhlo=1dd2e71331014ae0373f6bf900ce6be393357190
+llvm=f8cb7987c64dcffb72414a40560055cb717dbf74
+enzyme=v0.0.186
# Always remove custom PL/LQ versions before release.
diff --git a/.github/workflows/build-wheel-linux-arm64.yaml b/.github/workflows/build-wheel-linux-arm64.yaml
index df53cd5880..ef4b3d04ce 100644
--- a/.github/workflows/build-wheel-linux-arm64.yaml
+++ b/.github/workflows/build-wheel-linux-arm64.yaml
@@ -108,6 +108,12 @@ jobs:
ref: ${{ needs.constants.outputs.llvm_version }}
path: ${{ github.workspace }}/mlir/llvm-project
+ - name: Patch LLVM Source
+ if: steps.cache-mhlo-source.outputs.cache-hit != 'true'
+ run: |
+ cd $GITHUB_WORKSPACE/mlir/llvm-project
+ git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch
+
- name: Clone MHLO Submodule
if: steps.cache-mhlo-source.outputs.cache-hit != 'true'
uses: actions/checkout@v4
@@ -122,6 +128,7 @@ jobs:
cd $GITHUB_WORKSPACE/mlir/mlir-hlo
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch
+ git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-rename-sort.patch
- name: Clone Enzyme Submodule
if: steps.cache-enzyme-source.outputs.cache-hit != 'true'
@@ -134,9 +141,8 @@ jobs:
- name: Patch Enzyme Source
if: steps.cache-enzyme-source.outputs.cache-hit != 'true'
run: |
- export TARGET_FILE=$GITHUB_WORKSPACE/mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
- export PATCH_FILE=$GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
- patch -p1 $TARGET_FILE $PATCH_FILE
+ cd $GITHUB_WORKSPACE/mlir/Enzyme
+ git apply $GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
# Cache external project builds
- name: Restore LLVM Build
diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml b/.github/workflows/build-wheel-linux-x86_64.yaml
index 35be761c99..bde5f9cbbc 100644
--- a/.github/workflows/build-wheel-linux-x86_64.yaml
+++ b/.github/workflows/build-wheel-linux-x86_64.yaml
@@ -127,6 +127,12 @@ jobs:
ref: ${{ needs.constants.outputs.llvm_version }}
path: ${{ github.workspace }}/mlir/llvm-project
+ - name: Patch LLVM Source
+ if: steps.cache-mhlo-source.outputs.cache-hit != 'true'
+ run: |
+ cd $GITHUB_WORKSPACE/mlir/llvm-project
+ git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch
+
- name: Clone MHLO Submodule
if: steps.cache-mhlo-source.outputs.cache-hit != 'true'
uses: actions/checkout@v4
@@ -141,6 +147,7 @@ jobs:
cd $GITHUB_WORKSPACE/mlir/mlir-hlo
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch
+ git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-rename-sort.patch
- name: Clone Enzyme Submodule
if: steps.cache-enzyme-source.outputs.cache-hit != 'true'
@@ -153,9 +160,8 @@ jobs:
- name: Patch Enzyme Source
if: steps.cache-enzyme-source.outputs.cache-hit != 'true'
run: |
- export TARGET_FILE=$GITHUB_WORKSPACE/mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
- export PATCH_FILE=$GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
- patch -p1 $TARGET_FILE $PATCH_FILE
+ cd $GITHUB_WORKSPACE/mlir/Enzyme
+ git apply $GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
# Cache external project builds
- name: Restore LLVM Build
diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml
index d6d2e9d215..fdc101e3a8 100644
--- a/.github/workflows/build-wheel-macos-arm64.yaml
+++ b/.github/workflows/build-wheel-macos-arm64.yaml
@@ -113,6 +113,12 @@ jobs:
ref: ${{ needs.constants.outputs.llvm_version }}
path: ${{ github.workspace }}/mlir/llvm-project
+ - name: Patch LLVM Source
+ if: steps.cache-mhlo-source.outputs.cache-hit != 'true'
+ run: |
+ cd $GITHUB_WORKSPACE/mlir/llvm-project
+ git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch
+
- name: Clone MHLO Submodule
if: steps.cache-mhlo-source.outputs.cache-hit != 'true'
uses: actions/checkout@v4
@@ -127,6 +133,7 @@ jobs:
cd $GITHUB_WORKSPACE/mlir/mlir-hlo
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch
+ git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-rename-sort.patch
- name: Clone Enzyme Submodule
if: steps.cache-enzyme-source.outputs.cache-hit != 'true'
@@ -139,9 +146,8 @@ jobs:
- name: Patch Enzyme Source
if: steps.cache-enzyme-source.outputs.cache-hit != 'true'
run: |
- export TARGET_FILE=$GITHUB_WORKSPACE/mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
- export PATCH_FILE=$GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
- patch -p1 $TARGET_FILE $PATCH_FILE
+ cd $GITHUB_WORKSPACE/mlir/Enzyme
+ git apply $GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
# Cache external project builds
- name: Restore LLVM Build
diff --git a/Makefile b/Makefile
index e3309167f0..68d81ae519 100644
--- a/Makefile
+++ b/Makefile
@@ -282,6 +282,9 @@ clean-plugin:
clean-llvm:
$(MAKE) -C mlir clean-llvm
+reset-llvm:
+ $(MAKE) -C mlir reset-llvm
+
clean-mhlo:
$(MAKE) -C mlir clean-mhlo
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 32688e2dab..2c03b31a80 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -12,6 +12,16 @@
* The JAX version used by Catalyst is updated to 0.6.2.
[(#1897)](https://github.com/PennyLaneAI/catalyst/pull/1897)
+* The versions of LLVM, mlir-hlo, and Enzyme used by Catalyst have been updated.
+ [(#1916)](https://github.com/PennyLaneAI/catalyst/pull/1916)
+
+ The LLVM version has been updated to
+ [commit f8cb798](https://github.com/llvm/llvm-project/tree/f8cb7987c64dcffb72414a40560055cb717dbf74).
+ The mlir-hlo version has been updated to
+ [commit 1dd2e71](https://github.com/tensorflow/mlir-hlo/tree/1dd2e71331014ae0373f6bf900ce6be393357190).
+ The Enzyme version has been updated to
+ [v0.0.186](https://github.com/EnzymeAD/Enzyme/releases/tag/v0.0.186).
+
Deprecations 👋
Bug fixes 🐛
diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py
index d2bfbc74bb..7dc1382593 100644
--- a/frontend/catalyst/jax_extras/lowering.py
+++ b/frontend/catalyst/jax_extras/lowering.py
@@ -16,6 +16,7 @@
from __future__ import annotations
import logging
+import textwrap
import jax
from jax._src.dispatch import jaxpr_replicas
@@ -38,6 +39,7 @@
import catalyst
from catalyst.logging import debug_logger
+from catalyst.utils.exceptions import CompileError
from catalyst.utils.patching import Patcher
# pylint: disable=protected-access
@@ -165,3 +167,57 @@ def custom_lower_jaxpr_to_module(
worklist += [*op.body.operations]
return ctx.module, ctx.context
+
+
+def get_mlir_attribute_from_pyval(value):
+ """
+ Given a value of any type, construct an mlir attribute of corresponding type.
+
+ We set up the context and location outside because recursive calls to this function
+ will segfault if multiple `Context()`s are instantiated.
+ """
+
+ attr = None
+ match value:
+ case bool():
+ attr = ir.BoolAttr.get(value)
+
+ case int():
+            if -9223372036854775808 <= value < 0:  # = -(2**63)
+ attr = ir.IntegerAttr.get(ir.IntegerType.get_signed(64), value)
+ elif 0 <= value < 18446744073709551616: # = 2**64
+ attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value)
+ else:
+ raise CompileError(
+ textwrap.dedent(
+ """
+                        Large integer attributes currently not supported in MLIR,
+ see https://github.com/llvm/llvm-project/issues/128072
+ """
+ )
+ )
+
+ case float():
+ attr = ir.FloatAttr.get(ir.F64Type.get(), value)
+
+ case str():
+ attr = ir.StringAttr.get(value)
+
+ case list() | tuple():
+ element_attrs = [get_mlir_attribute_from_pyval(elem) for elem in value]
+ attr = ir.ArrayAttr.get(element_attrs)
+
+ case dict():
+ named_attrs = {}
+ for k, v in value.items():
+ if not isinstance(k, str):
+ raise CompileError(
+ f"Dictionary keys for MLIR DictionaryAttr must be strings, got: {type(k)}"
+ )
+ named_attrs[k] = get_mlir_attribute_from_pyval(v)
+ attr = ir.DictAttr.get(named_attrs)
+
+ case _:
+ raise CompileError(f"Cannot convert Python type {type(value)} to an MLIR attribute.")
+
+ return attr
diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py
index cd2d27eb53..effa498b45 100644
--- a/frontend/catalyst/jax_primitives_utils.py
+++ b/frontend/catalyst/jax_primitives_utils.py
@@ -280,7 +280,11 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin
for _pass in pipeline:
options = _pass.get_options()
apply_registered_pass_op = ApplyRegisteredPassOp(
- result=transform_mod_type, target=target, pass_name=_pass.name, options=options
+ result=transform_mod_type,
+ target=target,
+ pass_name=_pass.name,
+ options=options,
+ dynamic_options={},
)
target = apply_registered_pass_op.result
transform_yield_op = YieldOp(operands_=[]) # pylint: disable=unused-variable
diff --git a/frontend/catalyst/passes/pass_api.py b/frontend/catalyst/passes/pass_api.py
index 423f8c5269..b0ed5db380 100644
--- a/frontend/catalyst/passes/pass_api.py
+++ b/frontend/catalyst/passes/pass_api.py
@@ -19,6 +19,7 @@
import pennylane as qml
+from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval
from catalyst.tracing.contexts import EvaluationContext
PipelineDict: TypeAlias = dict[str, dict[str, str]]
@@ -286,23 +287,18 @@ def __init__(self, name: str, *options: list[str], **valued_options: dict[str, s
def get_options(self):
"""
- Stringify options according to what mlir-opt expects.
-
- ApplyRegisteredPassOp expects options to be a single StringAttr
- which follows the same format as the one used with mlir-opt.
-
- https://mlir.llvm.org/docs/Dialects/Transform/#transformapply_registered_pass-transformapplyregisteredpassop
-
- Options passed to a pass are specified via the syntax {option1=value1 option2=value2 ...},
- i.e., use space-separated key=value pairs for each option.
+ Build a dictionary mapping option names to MLIR attributes.
+ ApplyRegisteredPassOp expects options to be a dictionary from strings to attributes.
+ See https://github.com/llvm/llvm-project/pull/143159
+ """
+ options_dict = {}
+ for option in self.options:
+ options_dict[str(option)] = get_mlir_attribute_from_pyval(True)
- https://mlir.llvm.org/docs/Tutorials/MlirOpt/#running-a-pass-with-options
+ for option, value in self.valued_options.items():
+ options_dict[str(option)] = get_mlir_attribute_from_pyval(value)
- Experimentally we found that single-options also work without values.
- """
- retval = " ".join(f"{str(option)}" for option in self.options)
- retval2 = " ".join(f"{str(key)}={str(value)}" for key, value in self.valued_options.items())
- return " ".join([retval, retval2]).strip()
+ return options_dict
def __repr__(self):
return (
diff --git a/frontend/test/lit/test_mlir_plugin.py b/frontend/test/lit/test_mlir_plugin.py
index 0be8d5cbb1..09d519c7ef 100644
--- a/frontend/test/lit/test_mlir_plugin.py
+++ b/frontend/test/lit/test_mlir_plugin.py
@@ -106,7 +106,7 @@ def test_pass_options():
"""Is the option in the generated MLIR?"""
@qjit(target="mlir")
- # CHECK: options = "an-option maxValue=1"
+ # CHECK: options = {"an-option" = true, "maxValue" = 1 : i64}
@catalyst.passes.apply_pass("some-pass", "an-option", maxValue=1)
@qml.qnode(qml.device("null.qubit", wires=1))
def example():
diff --git a/frontend/test/pytest/test_jax_integration.py b/frontend/test/pytest/test_jax_integration.py
index 99b78f29b6..37e2d0762a 100644
--- a/frontend/test/pytest/test_jax_integration.py
+++ b/frontend/test/pytest/test_jax_integration.py
@@ -14,15 +14,19 @@
"""Test QJIT compatibility with JAX transformations such as jax.jit and jax.grad."""
+import textwrap
from functools import partial
import jax
import jax.numpy as jnp
import pennylane as qml
import pytest
+from jax.interpreters.mlir import ir
from catalyst import for_loop, measure, qjit
+from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval
from catalyst.jit import JAX_QJIT
+from catalyst.utils.exceptions import CompileError
class TestJAXJIT:
@@ -490,5 +494,122 @@ def ansatz(i, x):
jax.grad(circuit, argnums=0)(params, 3)
+ctx = ir.Context()
+loc = ir.Location.unknown(ctx)
+
+
+class TestJAXMLIRAttributeGetter:
+ """
+ Test catalyst.jax_extras.lowering.get_mlir_attribute_from_pyval
+ """
+
+ def test_bool_attr(self):
+ """
+ Test bool attribute.
+ """
+ with ctx, loc:
+ attr = get_mlir_attribute_from_pyval(True)
+ assert isinstance(attr, ir.BoolAttr)
+ assert attr.value == True
+
+ def test_str_attr(self):
+ """
+ Test string attribute.
+ """
+ with ctx, loc:
+ attr = get_mlir_attribute_from_pyval("hello catalyst!")
+ assert isinstance(attr, ir.StringAttr)
+ assert attr.value == "hello catalyst!"
+
+ @pytest.mark.parametrize("number", (37, -37))
+ def test_int_attr(self, number):
+ """
+ Test integer attribute.
+ """
+ with ctx, loc:
+ attr = get_mlir_attribute_from_pyval(number)
+ assert isinstance(attr, ir.IntegerAttr)
+ assert attr.value == number
+
+ @pytest.mark.parametrize("number", (3.7, -3.7))
+ def test_float_attr(self, number):
+ """
+ Test float attribute.
+ """
+ with ctx, loc:
+ attr = get_mlir_attribute_from_pyval(number)
+ assert isinstance(attr, ir.FloatAttr)
+ assert attr.value == number
+
+ @pytest.mark.parametrize("array", ([1, 2, 3], (4, 5, 6)))
+ def test_array_attr(self, array):
+ """
+ Test array attribute.
+ """
+ with ctx, loc:
+ attr = get_mlir_attribute_from_pyval(array)
+ assert isinstance(attr, ir.ArrayAttr)
+ assert len(attr) == len(array)
+
+ for attr_val, py_val in zip(attr, array):
+ assert isinstance(attr_val, ir.IntegerAttr)
+ assert attr_val.value == py_val
+
+ def test_dict_attr(self):
+ """
+ Test dictionary attribute.
+ """
+ with ctx, loc:
+ attr = get_mlir_attribute_from_pyval(
+ {"device": "lightning.qubit", "wire_capacity": 100}
+ )
+ assert isinstance(attr, ir.DictAttr)
+
+ assert isinstance(attr["device"], ir.StringAttr)
+ assert attr["device"].value == "lightning.qubit"
+
+ assert isinstance(attr["wire_capacity"], ir.IntegerAttr)
+ assert attr["wire_capacity"].value == 100
+
+ def test_dict_attr_with_bad_keys(self):
+ """
+ Test dictionary attribute with non-string keys.
+ """
+ with pytest.raises(
+ CompileError, match="Dictionary keys for MLIR DictionaryAttr must be strings"
+ ):
+ with ctx, loc:
+ _ = get_mlir_attribute_from_pyval({37: 42})
+
+ def test_bad_type(self):
+ """
+ Test an error is correctly raised on a python type not convertible to mlir attribute.
+ """
+
+ # pylint: disable=missing-class-docstring
+ class Foo:
+ pass
+
+ with pytest.raises(CompileError, match="Cannot convert Python type"):
+ with ctx, loc:
+ _ = get_mlir_attribute_from_pyval(Foo())
+
+ def test_int_attr_overflow(self):
+ """
+ Test int attribute with overflow correctly raises error.
+ """
+ with pytest.raises(
+ CompileError,
+ match=textwrap.dedent(
+ """
+                Large integer attributes currently not supported in MLIR,
+ see https://github.com/llvm/llvm-project/issues/128072
+ """
+ ),
+ ):
+ with ctx, loc:
+ _ = get_mlir_attribute_from_pyval(2**100)
+
+
if __name__ == "__main__":
pytest.main(["-x", __file__])
diff --git a/frontend/test/pytest/test_mlir_plugin_interface.py b/frontend/test/pytest/test_mlir_plugin_interface.py
index 30f7c040b9..afde4c920f 100644
--- a/frontend/test/pytest/test_mlir_plugin_interface.py
+++ b/frontend/test/pytest/test_mlir_plugin_interface.py
@@ -19,6 +19,7 @@
import pennylane as qml
import pytest
+from jax.interpreters.mlir import ir
import catalyst
from catalyst import qjit
@@ -73,24 +74,26 @@ def test_get_options():
"""
Test get_options from Pass
- ApplyRegisteredPassOp expects options to be a single StringAttr
- which follows the same format as the one used with mlir-opt.
-
- https://mlir.llvm.org/docs/Dialects/Transform/#transformapply_registered_pass-transformapplyregisteredpassop
-
- Options passed to a pass are specified via the syntax {option1=value1 option2=value2 ...},
- i.e., use space-separated key=value pairs for each option.
-
- https://mlir.llvm.org/docs/Tutorials/MlirOpt/#running-a-pass-with-options
-
- However, experimentally we found that single-options also work without values.
+ ApplyRegisteredPassOp expects options to be a dictionary from strings to attributes.
+ See https://github.com/llvm/llvm-project/pull/143159
"""
- assert catalyst.passes.Pass("example-pass", "single-option").get_options() == "single-option"
- assert (
- catalyst.passes.Pass("example-pass", "an-option", "bn-option").get_options()
- == "an-option bn-option"
- )
- assert catalyst.passes.Pass("example-pass", option=True).get_options() == "option=True"
+ with ir.Context(), ir.Location.unknown():
+ options = catalyst.passes.Pass("example-pass", "single-option").get_options()
+ assert isinstance(options, dict)
+ assert isinstance(options["single-option"], ir.BoolAttr)
+ assert options["single-option"].value == True
+
+ options = catalyst.passes.Pass("example-pass", "an-option", "bn-option").get_options()
+ assert isinstance(options, dict)
+ assert isinstance(options["an-option"], ir.BoolAttr)
+ assert options["an-option"].value == True
+ assert isinstance(options["bn-option"], ir.BoolAttr)
+ assert options["bn-option"].value == True
+
+ options = catalyst.passes.Pass("example-pass", option=True).get_options()
+ assert isinstance(options, dict)
+ assert isinstance(options["option"], ir.BoolAttr)
+ assert options["option"].value == True
@pytest.mark.skip(reason="xdsl not installed in ci cd yet")
diff --git a/mlir/Enzyme b/mlir/Enzyme
index db0181320d..8c1a596158 160000
--- a/mlir/Enzyme
+++ b/mlir/Enzyme
@@ -1 +1 @@
-Subproject commit db0181320d6e425ee963bd496ed0d8dbb615be18
+Subproject commit 8c1a596158f6194f10e8ffd56a1660a61c54337e
diff --git a/mlir/Makefile b/mlir/Makefile
index 5b3dc53f53..0a41527370 100644
--- a/mlir/Makefile
+++ b/mlir/Makefile
@@ -57,6 +57,14 @@ all: llvm mhlo enzyme dialects plugin
.PHONY: llvm
llvm:
@echo "build LLVM and MLIR enabling Python bindings"
+
+ # Patch mlir one shot bufferization segfault
+ # Remove patch after bug is resolved upstream
+ # https://github.com/llvm/llvm-project/issues/150441
+ @if cd llvm-project; git apply --check $(MK_DIR)/patches/llvm-bufferization-segfault.patch; then \
+ git apply $(MK_DIR)/patches/llvm-bufferization-segfault.patch; \
+ fi
+
cmake -G Ninja -S llvm-project/llvm -B $(LLVM_BUILD_DIR) \
-DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \
-DLLVM_BUILD_EXAMPLES=OFF \
@@ -97,6 +105,11 @@ mhlo:
@if cd mlir-hlo; git apply --check $(MK_DIR)/patches/mhlo-add-back-necessary-passes.patch; then \
git apply $(MK_DIR)/patches/mhlo-add-back-necessary-passes.patch; \
fi
+
+ # Patch a MHLO bug with std::sort
+ @if cd mlir-hlo; git apply --check $(MK_DIR)/patches/mhlo-rename-sort.patch; then \
+ git apply $(MK_DIR)/patches/mhlo-rename-sort.patch; \
+ fi
cmake -G Ninja -S mlir-hlo -B $(MHLO_BUILD_DIR) \
-DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \
-DLLVM_ENABLE_ASSERTIONS=ON \
@@ -121,8 +134,8 @@ enzyme: PATCH_FILE := $(MK_DIR)/patches/enzyme-nvvm-fabs-intrinsics.patch
enzyme:
@echo "build enzyme"
# Patch enzyme's dependency on nvidia fabs llvm intrinsics
- @if patch --dry-run -p1 -N $(TARGET_FILE) $(PATCH_FILE) > /dev/null 2>&1; then \
- patch -p1 $(TARGET_FILE) $(PATCH_FILE); \
+ @if cd Enzyme; git apply --check $(MK_DIR)/patches/enzyme-nvvm-fabs-intrinsics.patch; then \
+ git apply $(MK_DIR)/patches/enzyme-nvvm-fabs-intrinsics.patch; \
fi
cmake -G Ninja -S Enzyme/enzyme -B $(ENZYME_BUILD_DIR) \
-DENZYME_STATIC_LIB=ON \
@@ -204,6 +217,11 @@ clean-dialects:
clean-llvm:
@echo "clean llvm/mlir build files"
rm -rf $(LLVM_BUILD_DIR)
+	cd llvm-project && git clean -fd && git checkout .
+
+reset-llvm:
+ @echo "reset llvm git state to the commit tracked in .dep-versions without deleting llvm builds"
+	cd llvm-project && git clean -fd && git checkout .
clean-mhlo:
@echo "clean HLO dialect build files"
@@ -213,6 +231,7 @@ clean-mhlo:
clean-enzyme:
@echo "clean enzyme build files"
rm -rf $(ENZYME_BUILD_DIR)
+	cd Enzyme && git clean -fd && git checkout .
clean-plugin:
@echo "clean plugin"
diff --git a/mlir/include/Catalyst/Transforms/AsyncUtils.h b/mlir/include/Catalyst/Transforms/AsyncUtils.h
index 7be6815235..98df1f3972 100644
--- a/mlir/include/Catalyst/Transforms/AsyncUtils.h
+++ b/mlir/include/Catalyst/Transforms/AsyncUtils.h
@@ -63,13 +63,13 @@ bool hasAbortInBlock(Block *block);
bool hasPutsInBlock(Block *block);
// Helper function for creating function declarations
-LLVM::LLVMFuncOp lookupOrCreatePersonality(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateAbort(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMlirAsyncRuntimeSetValueError(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMlirAsyncRuntimeSetTokenError(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateUnrecoverableError(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateAwaitTokenName(ModuleOp);
-LLVM::LLVMFuncOp lookupOrCreateAwaitValueName(ModuleOp);
-LLVM::LLVMFuncOp lookupOrCreateDropRef(ModuleOp);
+LLVM::LLVMFuncOp lookupOrCreatePersonality(OpBuilder &b, ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateAbort(OpBuilder &b, ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMlirAsyncRuntimeSetValueError(OpBuilder &b, ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMlirAsyncRuntimeSetTokenError(OpBuilder &b, ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateUnrecoverableError(OpBuilder &b, ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateAwaitTokenName(OpBuilder &b, ModuleOp);
+LLVM::LLVMFuncOp lookupOrCreateAwaitValueName(OpBuilder &b, ModuleOp);
+LLVM::LLVMFuncOp lookupOrCreateDropRef(OpBuilder &b, ModuleOp);
}; // namespace AsyncUtils
diff --git a/mlir/include/Gradient/Utils/DestinationPassingStyle.h b/mlir/include/Gradient/Utils/DestinationPassingStyle.h
index 920cfc45f3..7a6ae53415 100644
--- a/mlir/include/Gradient/Utils/DestinationPassingStyle.h
+++ b/mlir/include/Gradient/Utils/DestinationPassingStyle.h
@@ -17,5 +17,6 @@
namespace catalyst {
/// Convert every MemRef-typed return value in `callee` to writing to a new argument in
/// destination-passing style.
-void convertToDestinationPassingStyle(mlir::func::FuncOp callee, mlir::OpBuilder &builder);
+llvm::LogicalResult convertToDestinationPassingStyle(mlir::func::FuncOp callee,
+ mlir::OpBuilder &builder);
} // namespace catalyst
diff --git a/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp b/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp
index 76cee871d4..ad3ec61a72 100644
--- a/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp
+++ b/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp
@@ -128,81 +128,84 @@ LLVM::LLVMFuncOp AsyncUtils::getCaller(LLVM::CallOp callOp)
return callOp->getParentOfType();
}
-LLVM::LLVMFuncOp AsyncUtils::lookupOrCreatePersonality(ModuleOp moduleOp)
+LLVM::LLVMFuncOp AsyncUtils::lookupOrCreatePersonality(OpBuilder &b, ModuleOp moduleOp)
{
MLIRContext *ctx = moduleOp.getContext();
auto i32Ty = IntegerType::get(ctx, 32);
bool isVarArg = true;
- return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::personalityName, {}, i32Ty,
- isVarArg)
+ return mlir::LLVM::lookupOrCreateFn(b, moduleOp, AsyncUtilsConstants::personalityName, {},
+ i32Ty, isVarArg)
.value();
}
-LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAbort(ModuleOp moduleOp)
+LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAbort(OpBuilder &b, ModuleOp moduleOp)
{
MLIRContext *ctx = moduleOp.getContext();
auto voidTy = LLVM::LLVMVoidType::get(ctx);
- return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::abortName, {}, voidTy)
+ return mlir::LLVM::lookupOrCreateFn(b, moduleOp, AsyncUtilsConstants::abortName, {}, voidTy)
.value();
}
-LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitTokenName(ModuleOp moduleOp)
+LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitTokenName(OpBuilder &b, ModuleOp moduleOp)
{
MLIRContext *ctx = moduleOp.getContext();
Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext());
auto voidTy = LLVM::LLVMVoidType::get(ctx);
return mlir::LLVM::lookupOrCreateFn(
- moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitTokenName, {ptrTy}, voidTy)
+ b, moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitTokenName, {ptrTy}, voidTy)
.value();
}
-LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitValueName(ModuleOp moduleOp)
+LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitValueName(OpBuilder &b, ModuleOp moduleOp)
{
MLIRContext *ctx = moduleOp.getContext();
Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext());
auto voidTy = LLVM::LLVMVoidType::get(ctx);
return mlir::LLVM::lookupOrCreateFn(
- moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitValueName, {ptrTy}, voidTy)
+ b, moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitValueName, {ptrTy}, voidTy)
.value();
}
-LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateDropRef(ModuleOp moduleOp)
+LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateDropRef(OpBuilder &b, ModuleOp moduleOp)
{
MLIRContext *ctx = moduleOp.getContext();
Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext());
Type llvmInt64Type = IntegerType::get(moduleOp.getContext(), 64);
auto voidTy = LLVM::LLVMVoidType::get(ctx);
- return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeDropRefName,
+ return mlir::LLVM::lookupOrCreateFn(b, moduleOp,
+ AsyncUtilsConstants::mlirAsyncRuntimeDropRefName,
{ptrTy, llvmInt64Type}, voidTy)
.value();
}
-LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(ModuleOp moduleOp)
+LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(OpBuilder &b,
+ ModuleOp moduleOp)
{
MLIRContext *ctx = moduleOp.getContext();
Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext());
auto voidTy = LLVM::LLVMVoidType::get(ctx);
return mlir::LLVM::lookupOrCreateFn(
- moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetValueErrorName, {ptrTy}, voidTy)
+ b, moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetValueErrorName, {ptrTy}, voidTy)
.value();
}
-LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(ModuleOp moduleOp)
+LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(OpBuilder &b,
+ ModuleOp moduleOp)
{
MLIRContext *ctx = moduleOp.getContext();
Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext());
auto voidTy = LLVM::LLVMVoidType::get(ctx);
return mlir::LLVM::lookupOrCreateFn(
- moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetTokenErrorName, {ptrTy}, voidTy)
+ b, moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetTokenErrorName, {ptrTy}, voidTy)
.value();
}
-LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateUnrecoverableError(ModuleOp moduleOp)
+LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateUnrecoverableError(OpBuilder &b, ModuleOp moduleOp)
{
MLIRContext *ctx = moduleOp.getContext();
auto voidTy = LLVM::LLVMVoidType::get(ctx);
- return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::unrecoverableErrorName, {},
- voidTy)
+ return mlir::LLVM::lookupOrCreateFn(b, moduleOp, AsyncUtilsConstants::unrecoverableErrorName,
+ {}, voidTy)
.value();
}
diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp
index c5b93d1602..ebd559a188 100644
--- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -56,11 +56,12 @@ struct PrintOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const bufferization::BufferizationOptions &options) const
+ const bufferization::BufferizationOptions &options,
+ bufferization::BufferizationState &state) const
{
auto printOp = cast(op);
if (printOp.getVal()) {
- FailureOr source = getBuffer(rewriter, printOp.getVal(), options);
+ FailureOr source = getBuffer(rewriter, printOp.getVal(), options, state);
if (failed(source)) {
return failure();
}
@@ -116,7 +117,8 @@ struct CustomCallOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const bufferization::BufferizationOptions &options) const
+ const bufferization::BufferizationOptions &options,
+ bufferization::BufferizationState &state) const
{
auto customCallOp = cast(op);
@@ -124,7 +126,7 @@ struct CustomCallOpInterface
SmallVector bufferArgs;
ValueRange operands = customCallOp.getOperands();
for (Value operand : operands) {
- FailureOr opBuffer = getBuffer(rewriter, operand, options);
+ FailureOr opBuffer = getBuffer(rewriter, operand, options, state);
if (failed(opBuffer)) {
return failure();
}
@@ -165,11 +167,11 @@ struct CustomCallOpInterface
}
auto options = bufferization::BufferizationOptions();
FailureOr tensorAlloc = bufferization::allocateTensorForShapedValue(
- rewriter, op->getLoc(), result, options, false);
+ rewriter, op->getLoc(), result, options, state, false);
MemRefType memrefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
auto newBuffer =
- rewriter.create(op->getLoc(), memrefType, *tensorAlloc);
+ rewriter.create(op->getLoc(), memrefType, *tensorAlloc);
bufferArgs.push_back(newBuffer);
}
@@ -207,7 +209,8 @@ struct CallbackOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const bufferization::BufferizationOptions &options) const
+ const bufferization::BufferizationOptions &options,
+ bufferization::BufferizationState &state) const
{
auto callbackOp = cast(op);
@@ -279,7 +282,8 @@ struct CallbackCallOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const bufferization::BufferizationOptions &options) const
+ const bufferization::BufferizationOptions &options,
+ bufferization::BufferizationState &state) const
{
auto callOp = cast(op);
@@ -292,7 +296,7 @@ struct CallbackCallOpInterface
SmallVector newInputs;
auto operands = callOp.getOperands();
for (Value operand : operands) {
- FailureOr opBuffer = getBuffer(rewriter, operand, options);
+ FailureOr opBuffer = getBuffer(rewriter, operand, options, state);
if (failed(opBuffer)) {
return failure();
}
@@ -303,8 +307,8 @@ struct CallbackCallOpInterface
auto loc = callOp->getLoc();
SmallVector outmemrefs;
for (auto result : results) {
- FailureOr tensorAlloc =
- bufferization::allocateTensorForShapedValue(rewriter, loc, result, options, false);
+ FailureOr tensorAlloc = bufferization::allocateTensorForShapedValue(
+ rewriter, loc, result, options, state, false);
if (failed(tensorAlloc)) {
return failure();
}
@@ -314,8 +318,8 @@ struct CallbackCallOpInterface
auto shape = tensorTy.getShape();
auto elementTy = tensorTy.getElementType();
auto memrefType = MemRefType::get(shape, elementTy);
- auto toMemrefOp = rewriter.create(loc, memrefType, tensor);
- auto memref = toMemrefOp.getResult();
+ auto toBufferOp = rewriter.create(loc, memrefType, tensor);
+ auto memref = toBufferOp.getResult();
outmemrefs.push_back(memref);
newInputs.push_back(memref);
}
diff --git a/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp b/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp
index d516605b4c..160da0d118 100644
--- a/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp
+++ b/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp
@@ -133,7 +133,7 @@ LogicalResult AddExceptionHandlingTransform::matchAndRewrite(LLVM::CallOp callOp
auto moduleOp = callOp->getParentOfType();
// Here, we are adding a reference to the personality declaration.
// From the documentation: https://llvm.org/docs/ExceptionHandling.html#exception-tables
- auto personality = AsyncUtils::lookupOrCreatePersonality(moduleOp);
+ auto personality = AsyncUtils::lookupOrCreatePersonality(rewriter, moduleOp);
// We annotate the body of the function containing the callop to have a reference
// to the personality.
@@ -294,7 +294,7 @@ RemoveAbortAndPutsInsertCallTransform::matchAndRewrite(LLVM::CallOp callOp,
// Here, we are declaring an external function which is available in the Catalyst runtime.
// llvm.func @__catalyst__host__rt__unrecoverable_error()
auto moduleOp = callOp->getParentOfType();
- auto unrecoverableError = AsyncUtils::lookupOrCreateUnrecoverableError(moduleOp);
+ auto unrecoverableError = AsyncUtils::lookupOrCreateUnrecoverableError(rewriter, moduleOp);
auto callee = maybeCallee.value();
rewriter.modifyOpInPlace(callee, [&] { callee.setLinkage(LLVM::Linkage::Internal); });
@@ -516,8 +516,8 @@ LogicalResult LivenessAnalysisDropRef::matchAndRewrite(LLVM::CallOp sink,
// llvm.func @mlirAsyncRuntimeAwaitValue(!llvm.ptr)
// llvm.func @mlirAsyncRuntimeAwaitToken(!llvm.ptr)
// llvm.func @mlirAsyncRuntimeDropRef(!llvm.ptr, i64)
- auto awaitFnDecl = AsyncUtils::lookupOrCreateAwaitTokenName(moduleOp);
- auto dropRefFnDecl = AsyncUtils::lookupOrCreateDropRef(moduleOp);
+ auto awaitFnDecl = AsyncUtils::lookupOrCreateAwaitTokenName(rewriter, moduleOp);
+ auto dropRefFnDecl = AsyncUtils::lookupOrCreateDropRef(rewriter, moduleOp);
Type llvmInt64Type = IntegerType::get(sink->getContext(), 64);
auto one = rewriter.getIntegerAttr(llvmInt64Type, 1);
@@ -871,9 +871,9 @@ void insertErrorCalls(std::vector tokens, std::vector values, Bloc
auto moduleOp = landingPad->getParentOfType();
LLVM::LLVMFuncOp setTokenError =
- AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(moduleOp);
+ AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(rewriter, moduleOp);
LLVM::LLVMFuncOp setValueError =
- AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(moduleOp);
+ AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(rewriter, moduleOp);
for (auto token : tokens) {
insertCallToMlirAsyncRuntimeErrorFunction(token, setTokenError, failBlock, rewriter);
}
@@ -918,11 +918,8 @@ struct AddExceptionHandlingPass : impl::AddExceptionHandlingPassBase(context);
GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingOps;
- config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
- // TODO: Update to the following lines the next time we update llvm
- // config.setStrictness(GreedyRewriteStrictness::ExistingOps);
- // config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);
+ config.setStrictness(GreedyRewriteStrictness::ExistingOps);
+ config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns1), config))) {
signalPassFailure();
diff --git a/mlir/lib/Catalyst/Transforms/GEPInboundsPatterns.cpp b/mlir/lib/Catalyst/Transforms/GEPInboundsPatterns.cpp
index a3a6af1139..773621f132 100644
--- a/mlir/lib/Catalyst/Transforms/GEPInboundsPatterns.cpp
+++ b/mlir/lib/Catalyst/Transforms/GEPInboundsPatterns.cpp
@@ -29,11 +29,12 @@ struct GEPOpRewritePattern : public mlir::OpRewritePattern {
mlir::PatternRewriter &rewriter) const override
{
auto defOp = op.getBase().getDefiningOp();
- if (op.getInbounds() || (defOp && isa(defOp))) {
+ if (op.getNoWrapFlags() == LLVM::GEPNoWrapFlags::inbounds ||
+ (defOp && isa(defOp))) {
return failure();
}
rewriter.startOpModification(op);
- op.setInbounds(true);
+ op.setNoWrapFlags(LLVM::GEPNoWrapFlags::inbounds);
rewriter.finalizeOpModification(op);
return success();
}
diff --git a/mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp b/mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp
index f540222b77..f115efb171 100644
--- a/mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp
+++ b/mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp
@@ -380,11 +380,8 @@ struct AnnotateWithFullyQualifiedNamePass
{
MLIRContext *context = &getContext();
GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingOps;
- config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
- // TODO: Update to the following lines the next time we update llvm
- // config.setStrictness(GreedyRewriteStrictness::ExistingOps);
- // config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);
+ config.setStrictness(GreedyRewriteStrictness::ExistingOps);
+ config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);
RewritePatternSet annotate(context);
auto root = getOperation();
@@ -409,11 +406,8 @@ struct InlineNestedSymbolTablePass : PassWrapper(op);
}) != backwardSlice.end();
@@ -132,8 +133,8 @@ struct MemrefLoadTBAARewritePattern : public ConvertOpToLLVMPattern(
loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0, false,
loadOp.getNontemporal());
@@ -170,8 +171,8 @@ struct MemrefStoreTBAARewritePattern : public ConvertOpToLLVMPattern(storeOp, adaptor.getValue(), dataPtr,
0, false, storeOp.getNontemporal());
diff --git a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp
index d427d49386..5d884162fc 100644
--- a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp
+++ b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp
@@ -51,7 +51,8 @@ Value getGlobalString(Location loc, OpBuilder &rewriter, StringRef key, StringRe
}
return rewriter.create(loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
type, rewriter.create(loc, glb),
- ArrayRef{0, 0}, true);
+ ArrayRef{0, 0},
+ LLVM::GEPNoWrapFlags::inbounds);
}
enum NumericType : int8_t {
@@ -309,7 +310,8 @@ Value EncodeDataMemRef(Location loc, PatternRewriter &rewriter, MemRefType memre
MemRefDescriptor desc = MemRefDescriptor(memrefLlvm);
Value c0 = rewriter.create(loc, rewriter.getI64IntegerAttr(0));
Value data = rewriter.create(loc, ptr, memrefType.getElementType(),
- desc.alignedPtr(rewriter, loc), c0, true);
+ desc.alignedPtr(rewriter, loc), c0,
+ LLVM::GEPNoWrapFlags::inbounds);
memref = rewriter.create(loc, memref, data, 1);
// Dtype
@@ -335,7 +337,8 @@ struct CustomCallOpPattern : public OpConversionPattern {
rewriter.setInsertionPointToStart(mod.getBody());
LLVM::LLVMFuncOp customCallFnOp =
- mlir::LLVM::lookupOrCreateFn(mod, op.getCallTargetName(), {/*args=*/ptr, /*rets=*/ptr},
+ mlir::LLVM::lookupOrCreateFn(rewriter, mod, op.getCallTargetName(),
+ {/*args=*/ptr, /*rets=*/ptr},
/*ret_type=*/voidType)
.value();
customCallFnOp.setPrivate();
@@ -467,7 +470,7 @@ struct DefineCallbackOpPattern : public OpConversionPattern {
ModuleOp mod = op->getParentOfType();
auto typeConverter = getTypeConverter();
LLVM::LLVMFuncOp customCallFnOp =
- mlir::LLVM::lookupOrCreateFn(mod, "__catalyst_inactive_callback",
+ mlir::LLVM::lookupOrCreateFn(rewriter, mod, "__catalyst_inactive_callback",
{/*args=*/i64, i64, i64},
/*ret_type=*/voidType, isVarArg)
.value();
diff --git a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp
index d1980a658e..54cbc85aac 100644
--- a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -177,7 +177,8 @@ struct AdjointOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const bufferization::BufferizationOptions &options) const
+ const bufferization::BufferizationOptions &options,
+ bufferization::BufferizationState &state) const
{
auto adjointOp = cast(op);
Location loc = adjointOp.getLoc();
@@ -207,7 +208,7 @@ struct AdjointOpInterface
ValueRange operands = adjointOp.getArgs();
for (Value operand : operands) {
if (isa(operand.getType())) {
- FailureOr opBuffer = getBuffer(rewriter, operand, options);
+ FailureOr opBuffer = getBuffer(rewriter, operand, options, state);
if (failed(opBuffer)) {
return failure();
}
@@ -276,7 +277,8 @@ struct BackpropOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const bufferization::BufferizationOptions &options) const
+ const bufferization::BufferizationOptions &options,
+ bufferization::BufferizationState &state) const
{
auto backpropOp = cast(op);
Location loc = backpropOp.getLoc();
@@ -293,7 +295,7 @@ struct BackpropOpInterface
ValueRange operands = backpropOp.getArgs();
for (Value operand : operands) {
if (isa(operand.getType())) {
- FailureOr opBuffer = getBuffer(rewriter, operand, options);
+ FailureOr opBuffer = getBuffer(rewriter, operand, options, state);
if (failed(opBuffer)) {
return failure();
}
@@ -333,7 +335,7 @@ struct BackpropOpInterface
ValueRange cotangents = backpropOp.getCotangents();
SmallVector bufferCotangents;
for (Value operand : cotangents) {
- FailureOr opBuffer = getBuffer(rewriter, operand, options);
+ FailureOr opBuffer = getBuffer(rewriter, operand, options, state);
if (failed(opBuffer)) {
return failure();
}
@@ -402,6 +404,7 @@ struct ForwardOpInterface
FailureOr getBufferType(Operation *op, Value value,
const bufferization::BufferizationOptions &options,
+ const bufferization::BufferizationState &state,
SmallVector &invocationStack) const
{
// The getBufferType() method is called on either BlockArguments or OpResults.
@@ -426,11 +429,12 @@ struct ForwardOpInterface
return getBufferizedFunctionArgType(forwardOp, bbArg.getArgNumber(), options);
}
- return bufferization::detail::defaultGetBufferType(value, options, invocationStack);
+ return bufferization::detail::defaultGetBufferType(value, options, state, invocationStack);
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const bufferization::BufferizationOptions &options) const
+ const bufferization::BufferizationOptions &options,
+ bufferization::BufferizationState &state) const
{
auto forwardOp = cast(op);
FunctionType funcType = forwardOp.getFunctionType();
@@ -451,7 +455,7 @@ struct ForwardOpInterface
// 1. Bufferize every block.
for (Block &block : forwardOp.getBody()) {
- if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options))) {
+ if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options, state))) {
return failure();
}
}
@@ -471,11 +475,12 @@ struct ForwardOpInterface
// Note: If `inferFunctionResultLayout = true`, cast are later folded
// away.
- BaseMemRefType resultType = options.unknownTypeConverterFn(
- returnVal, *options.defaultMemorySpaceFn(tensorType), options);
- Value toMemrefOp =
- rewriter.create(loc, resultType, returnVal);
- returnValues.push_back(toMemrefOp);
+ BaseMemRefType resultType =
+ options.unknownTypeConverterFn(cast(returnVal.getType()),
+ *options.defaultMemorySpaceFn(tensorType), options);
+ Value toBufferOp =
+ rewriter.create(loc, resultType, returnVal);
+ returnValues.push_back(toBufferOp);
}
// 3. Rewrite the terminator.
@@ -523,6 +528,7 @@ struct ReverseOpInterface
FailureOr getBufferType(Operation *op, Value value,
const bufferization::BufferizationOptions &options,
+ const bufferization::BufferizationState &state,
SmallVector &invocationStack) const
{
// See comment on the getBufferType() method on forward op.
@@ -534,11 +540,12 @@ struct ReverseOpInterface
return getBufferizedFunctionArgType(reverseOp, bbArg.getArgNumber(), options);
}
- return bufferization::detail::defaultGetBufferType(value, options, invocationStack);
+ return bufferization::detail::defaultGetBufferType(value, options, state, invocationStack);
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const bufferization::BufferizationOptions &options) const
+ const bufferization::BufferizationOptions &options,
+ bufferization::BufferizationState &state) const
{
auto reverseOp = cast(op);
FunctionType funcType = reverseOp.getFunctionType();
@@ -559,7 +566,7 @@ struct ReverseOpInterface
// 1. Bufferize every block.
for (Block &block : reverseOp.getBody())
- if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options)))
+ if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options, state)))
return failure();
// 2. For each result, keep track of which inplace argument it reuses.
@@ -577,11 +584,12 @@ struct ReverseOpInterface
// Note: If `inferFunctionResultLayout = true`, cast are later folded
// away.
- BaseMemRefType resultType = options.unknownTypeConverterFn(
- returnVal, *options.defaultMemorySpaceFn(tensorType), options);
- Value toMemrefOp =
- rewriter.create(loc, resultType, returnVal);
- returnValues.push_back(toMemrefOp);
+ BaseMemRefType resultType =
+ options.unknownTypeConverterFn(cast(returnVal.getType()),
+ *options.defaultMemorySpaceFn(tensorType), options);
+ Value toBufferOp =
+ rewriter.create(loc, resultType, returnVal);
+ returnValues.push_back(toBufferOp);
}
// 3. Rewrite the terminator.
diff --git a/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp b/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp
index e3a472945e..4e2c4f188e 100644
--- a/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp
+++ b/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp
@@ -49,8 +49,8 @@ using namespace catalyst::gradient;
namespace catalyst {
namespace gradient {
-void wrapMemRefArgsFunc(func::FuncOp func, const TypeConverter *typeConverter,
- RewriterBase &rewriter, Location loc, bool volatileArgs = false)
+LogicalResult wrapMemRefArgsFunc(func::FuncOp func, const TypeConverter *typeConverter,
+ RewriterBase &rewriter, Location loc, bool volatileArgs = false)
{
MLIRContext *ctx = rewriter.getContext();
auto ptrType = LLVM::LLVMPointerType::get(ctx);
@@ -59,7 +59,9 @@ void wrapMemRefArgsFunc(func::FuncOp func, const TypeConverter *typeConverter,
for (const auto [idx, argType] : llvm::enumerate(func.getArgumentTypes())) {
if (auto memrefType = dyn_cast(argType)) {
BlockArgument memrefArg = func.getArgument(idx);
- func.insertArgument(idx, ptrType, DictionaryAttr::get(ctx), loc);
+ if (failed(func.insertArgument(idx, ptrType, DictionaryAttr::get(ctx), loc))) {
+ return failure();
+ }
Value wrappedMemref = func.getArgument(idx);
Type structType = typeConverter->convertType(memrefType);
@@ -78,9 +80,12 @@ void wrapMemRefArgsFunc(func::FuncOp func, const TypeConverter *typeConverter,
rewriter.create(loc, argType, replacedMemref)
.getResult(0);
memrefArg.replaceAllUsesWith(replacedMemref);
- func.eraseArgument(memrefArg.getArgNumber());
+ if (failed(func.eraseArgument(memrefArg.getArgNumber()))) {
+ return failure();
+ }
}
}
+ return success();
}
void wrapMemRefArgsCallsites(func::FuncOp func, const TypeConverter *typeConverter,
@@ -171,16 +176,19 @@ LLVM::GlobalOp insertEnzymeCustomGradient(OpBuilder &builder, ModuleOp moduleOp,
/// functions where MemRefs are passed via wrapped pointers (!llvm.ptr)
/// rather than having their fields unpacked. This function automatically transforms MemRef
/// arguments of a function to wrapped pointers.
-void wrapMemRefArgs(func::FuncOp func, const TypeConverter *typeConverter, RewriterBase &rewriter,
- Location loc, bool volatileArgs = false)
+LogicalResult wrapMemRefArgs(func::FuncOp func, const TypeConverter *typeConverter,
+ RewriterBase &rewriter, Location loc, bool volatileArgs = false)
{
if (llvm::none_of(func.getArgumentTypes(),
[](Type argType) { return isa(argType); })) {
// The memref arguments are already wrapped
- return;
+ return success();
+ }
+ if (failed(wrapMemRefArgsFunc(func, typeConverter, rewriter, loc, volatileArgs))) {
+ return failure();
}
- wrapMemRefArgsFunc(func, typeConverter, rewriter, loc, volatileArgs);
wrapMemRefArgsCallsites(func, typeConverter, rewriter, loc, volatileArgs);
+ return success();
}
} // namespace gradient
} // namespace catalyst
@@ -290,7 +298,9 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern {
SymbolTable::lookupNearestSymbolFrom(op, op.getCalleeAttr());
assert(callee && "Expected a valid callee of type func.func");
- catalyst::convertToDestinationPassingStyle(callee, rewriter);
+ if (failed(catalyst::convertToDestinationPassingStyle(callee, rewriter))) {
+ return failure();
+ }
SymbolTableCollection symbolTable;
catalyst::traverseCallGraph(callee, &symbolTable, [&](func::FuncOp func) {
// Register custom gradients of quantum functions
@@ -304,9 +314,11 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern {
if (!func->hasAttr("unwrapped_type")) {
func->setAttr("unwrapped_type", TypeAttr::get(func.getFunctionType()));
}
- catalyst::convertToDestinationPassingStyle(func, rewriter);
-
- wrapMemRefArgs(func, getTypeConverter(), rewriter, loc, /*volatileArgs=*/true);
+ LogicalResult dpsr = catalyst::convertToDestinationPassingStyle(func, rewriter);
+ assert(dpsr.succeeded() && "failed to rewrite backpropOp to destination style");
+ LogicalResult wmar = wrapMemRefArgs(func, getTypeConverter(), rewriter, loc,
+ /*volatileArgs=*/true);
+ assert(wmar.succeeded() && "failed to wrap backpropOp's memref args");
func::FuncOp augFwd = genAugmentedForward(func, rewriter);
func::FuncOp customQGrad =
@@ -318,10 +330,10 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern {
LowerToLLVMOptions options = getTypeConverter()->getOptions();
if (options.useGenericFunctions) {
- LLVM::LLVMFuncOp allocFn =
- LLVM::lookupOrCreateGenericAllocFn(moduleOp, getTypeConverter()->getIndexType())
- .value();
- LLVM::LLVMFuncOp freeFn = LLVM::lookupOrCreateGenericFreeFn(moduleOp).value();
+ LLVM::LLVMFuncOp allocFn = LLVM::lookupOrCreateGenericAllocFn(
+ rewriter, moduleOp, getTypeConverter()->getIndexType())
+ .value();
+ LLVM::LLVMFuncOp freeFn = LLVM::lookupOrCreateGenericFreeFn(rewriter, moduleOp).value();
// Register the previous functions as llvm globals (for Enzyme)
// With the following piece of metadata, shadow memory is allocated with
@@ -862,7 +874,10 @@ struct ForwardOpPattern : public ConvertOpToLLVMPattern {
func->setAttr("passthrough", ArrayAttr::get(ctx, passthrough));
rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
- catalyst::gradient::wrapMemRefArgsFunc(func, typeConverter, rewriter, op.getLoc());
+ if (failed(catalyst::gradient::wrapMemRefArgsFunc(func, typeConverter, rewriter,
+ op.getLoc()))) {
+ return failure();
+ }
rewriter.eraseOp(op);
return success();
}
@@ -884,16 +899,12 @@ struct ReverseOpPattern : public ConvertOpToLLVMPattern {
auto params = op.getArguments();
for (size_t i = 0; i < argc * 2; i++) {
- bool isDup = (i % 2) != 0;
- Value val = params[i];
- isDup ? differentials.push_back(val) : inputs.push_back(val);
+ fillValueAndShadowWithDedup(i, params, differentials, inputs);
}
auto upperLimit = (argc * 2) + (resc * 2);
for (size_t i = argc * 2; i < upperLimit; i++) {
- bool isDup = (i % 2) != 0;
- Value val = params[i];
- isDup ? cotangents.push_back(val) : outputs.push_back(val);
+ fillValueAndShadowWithDedup(i, params, cotangents, outputs);
}
auto tapeCount = op.getTape();
@@ -903,16 +914,7 @@ struct ReverseOpPattern : public ConvertOpToLLVMPattern {
}
SmallVector newFuncInputTys;
-
- for (auto [in, diff] : llvm::zip(inputs, differentials)) {
- newFuncInputTys.push_back(in.getType());
- newFuncInputTys.push_back(diff.getType());
- }
-
- for (auto [out, cotan] : llvm::zip(outputs, cotangents)) {
- newFuncInputTys.push_back(out.getType());
- newFuncInputTys.push_back(cotan.getType());
- }
+ getNewFuncInputTys(inputs, outputs, differentials, cotangents, newFuncInputTys);
SmallVector tapeStructs;
auto converter = getTypeConverter();
@@ -986,11 +988,40 @@ struct ReverseOpPattern : public ConvertOpToLLVMPattern {
Block &firstBlock = func.getRegion().getBlocks().front();
Block &lastBlock = func.getRegion().getBlocks().back();
rewriter.mergeBlocks(&lastBlock, &firstBlock);
- catalyst::gradient::wrapMemRefArgsFunc(func, typeConverter, rewriter, op.getLoc());
+ if (failed(catalyst::gradient::wrapMemRefArgsFunc(func, typeConverter, rewriter,
+ op.getLoc()))) {
+ return failure();
+ }
rewriter.eraseOp(op);
return success();
}
+
+ private:
+ static void getNewFuncInputTys(const SmallVector &inputs,
+ const SmallVector &outputs,
+ const SmallVector &differentials,
+ const SmallVector &cotangents,
+ SmallVector &newFuncInputTys)
+ {
+ for (auto [in, diff] : llvm::zip(inputs, differentials)) {
+ newFuncInputTys.push_back(in.getType());
+ newFuncInputTys.push_back(diff.getType());
+ }
+
+ for (auto [out, cotan] : llvm::zip(outputs, cotangents)) {
+ newFuncInputTys.push_back(out.getType());
+ newFuncInputTys.push_back(cotan.getType());
+ }
+ }
+
+ static void fillValueAndShadowWithDedup(size_t i, ValueRange params, SmallVector &values,
+ SmallVector &shadows)
+ {
+ bool isDup = (i % 2) != 0;
+ Value val = params[i];
+ isDup ? shadows.push_back(val) : values.push_back(val);
+ }
};
struct ReturnOpPattern : public ConvertOpToLLVMPattern {
diff --git a/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp b/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp
index 464ab29089..5c24252a1e 100644
--- a/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp
+++ b/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp
@@ -148,7 +148,8 @@ func::FuncOp genSplitPreprocessed(PatternRewriter &rewriter, Location loc, func:
PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(&splitFn.getBody().front());
Value paramsBuffer = rewriter.create(loc, paramsBufferType, paramCount);
- Value paramsTensor = rewriter.create(loc, paramsBuffer, true);
+ Value paramsTensor = rewriter.create(
+ loc, memref::getTensorTypeFromMemRefType(paramsBuffer.getType()), paramsBuffer, true);
qnodeQuantumArgs.push_back(paramsTensor);
MemRefType paramsProcessedType = MemRefType::get({}, rewriter.getIndexType());
diff --git a/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp b/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp
index 66fbcee25d..853ecd1e82 100644
--- a/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp
+++ b/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp
@@ -22,6 +22,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "Gradient/Utils/DifferentialQNode.h"
@@ -163,13 +164,14 @@ void FiniteDiffLowering::computeFiniteDiff(PatternRewriter &rewriter, Location l
auto tensorTy = diffArg.getType();
auto memrefTy = bufferization::getMemRefTypeWithStaticIdentityLayout(
cast(tensorTy));
- auto toMemrefOp =
- rewriter.create(loc, memrefTy, diffArg);
+ auto toBufferOp =
+ rewriter.create(loc, memrefTy, diffArg);
- auto cloneOp = rewriter.create(loc, toMemrefOp);
+ auto cloneOp = rewriter.create(loc, toBufferOp);
- auto toTensorOp =
- rewriter.create(loc, cloneOp, true);
+ auto toTensorOp = rewriter.create(
+ loc, memref::getTensorTypeFromMemRefType(cloneOp.getOutput().getType()),
+ cloneOp, true);
auto diffArgCopy = toTensorOp.getResult();
diff --git a/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp b/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp
index 5db5c4a149..75ad0a61a0 100644
--- a/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp
+++ b/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp
@@ -86,9 +86,9 @@ void initializeCotangents(TypeRange primalResultTypes, unsigned activeResult, Va
: activeResultType);
Value zero = builder.create(
- loc, APFloat(elementType.getFloatSemantics(), 0), elementType);
- Value one = builder.create(
- loc, APFloat(elementType.getFloatSemantics(), 1), elementType);
+ loc, elementType, APFloat(elementType.getFloatSemantics(), 0));
+ Value one = builder.create(loc, elementType,
+ APFloat(elementType.getFloatSemantics(), 1));
Value zeroTensor;
if (auto activeResultTensor = dyn_cast(activeResultType)) {
@@ -397,7 +397,7 @@ static func::FuncOp genFullGradFunction(PatternRewriter &rewriter, Location loc,
}
else {
jacobians.push_back(rewriter.create(
- loc, APFloat(0.0), cast(jacobianType)));
+ loc, cast(jacobianType), APFloat(0.0)));
}
}
diff --git a/mlir/lib/Gradient/Transforms/GradMethods/PS_FunctionShifting.cpp b/mlir/lib/Gradient/Transforms/GradMethods/PS_FunctionShifting.cpp
index ce91f8ec80..7227a3d35a 100644
--- a/mlir/lib/Gradient/Transforms/GradMethods/PS_FunctionShifting.cpp
+++ b/mlir/lib/Gradient/Transforms/GradMethods/PS_FunctionShifting.cpp
@@ -35,7 +35,7 @@ static Value genSelectiveShift(PatternRewriter &rewriter, Location loc, Value pa
}
// Make sure all active iteration variables match the selectors.
- Value shiftCondition = rewriter.create(loc, true, 1);
+ Value shiftCondition = rewriter.create(loc, 1, true);
for (auto &[iteration, selector] : selectors) {
Value iterationMatch =
rewriter.create(loc, arith::CmpIPredicate::eq, iteration, selector);
diff --git a/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp b/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp
index ff6d172908..d5cb0c117e 100644
--- a/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp
+++ b/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp
@@ -59,7 +59,8 @@ static std::vector computePartialDerivative(PatternRewriter &rewriter, Lo
{
constexpr double shift = llvm::numbers::pi / 2;
ShapedType shiftVectorType = RankedTensorType::get({numShifts}, rewriter.getF64Type());
- Value selectorVector = rewriter.create(loc, selectorBuffer, true);
+ Value selectorVector = rewriter.create(
+ loc, memref::getTensorTypeFromMemRefType(selectorBuffer.getType()), selectorBuffer, true);
// Define the shift vectors (pos/neg) as sparse tensor constants.
DenseElementsAttr nonZeroIndices = rewriter.getI64TensorAttr(currentShift);
@@ -285,8 +286,9 @@ func::FuncOp ParameterShiftLowering::genQGradFunction(PatternRewriter &rewriter,
std::vector gradientTensors;
gradientTensors.reserve(gradResTypes.size());
for (Value gradientBuffer : gradientBuffers) {
- gradientTensors.push_back(
- rewriter.create(loc, gradientBuffer, true));
+ gradientTensors.push_back(rewriter.create(
+ loc, memref::getTensorTypeFromMemRefType(gradientBuffer.getType()),
+ gradientBuffer, true));
}
op->setOperands(gradientTensors);
}
diff --git a/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp b/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp
index 70f623fc29..d34dd89974 100644
--- a/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp
+++ b/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp
@@ -108,7 +108,9 @@ struct PostprocessForwardOp : public OpRewritePattern {
// Insert new argIn in an interleaving way.
size_t idx = 0;
for (auto ty : newArgInTypes) {
- op.insertArgument(2 * idx + 1, ty, {}, op.getLoc());
+ if (failed(op.insertArgument(2 * idx + 1, ty, {}, op.getLoc()))) {
+ return failure();
+ }
idx++;
}
// Append newArgRes.
@@ -117,8 +119,11 @@ struct PostprocessForwardOp : public OpRewritePattern {
/*values=*/op.getNumArguments());
SmallVector argAttrs{appendingSize};
SmallVector argLocs{appendingSize, op.getLoc()};
- op.insertArguments(argIndices, newArgResTypes, argAttrs, argLocs);
+ if (failed(op.insertArguments(argIndices, newArgResTypes, argAttrs, argLocs))) {
+ return failure();
+ }
op.setFunctionType(forwardTy);
+ return success();
});
op.walk([&](ReturnOp returnOp) {
@@ -195,7 +200,9 @@ struct PostprocessReverseOp : public OpRewritePattern {
// Insert new argIn in an interleaving way.
size_t idx = 0;
for (auto ty : newArgInTypes) {
- op.insertArgument(2 * idx, ty, {}, op.getLoc());
+ if (failed(op.insertArgument(2 * idx, ty, {}, op.getLoc()))) {
+ return failure();
+ }
idx++;
}
// Append newArgRes.
@@ -204,8 +211,11 @@ struct PostprocessReverseOp : public OpRewritePattern {
/*values=*/0);
SmallVector argAttrs{appendingSize};
SmallVector argLocs{appendingSize, op.getLoc()};
- op.insertArguments(argIndices, newArgResTypes, argAttrs, argLocs);
+ if (failed(op.insertArguments(argIndices, newArgResTypes, argAttrs, argLocs))) {
+ return failure();
+ }
op.setFunctionType(reverseTy);
+ return success();
});
op.walk([&](ReturnOp returnOp) {
diff --git a/mlir/lib/Gradient/Utils/DestinationPassingStyle.cpp b/mlir/lib/Gradient/Utils/DestinationPassingStyle.cpp
index ef6b572583..08c082a316 100644
--- a/mlir/lib/Gradient/Utils/DestinationPassingStyle.cpp
+++ b/mlir/lib/Gradient/Utils/DestinationPassingStyle.cpp
@@ -19,11 +19,11 @@
using namespace mlir;
-void catalyst::convertToDestinationPassingStyle(func::FuncOp callee, OpBuilder &builder)
+LogicalResult catalyst::convertToDestinationPassingStyle(func::FuncOp callee, OpBuilder &builder)
{
if (callee.getNumResults() == 0) {
// Callee is already in destination-passing style
- return;
+ return success();
}
MLIRContext *ctx = callee.getContext();
@@ -48,7 +48,7 @@ void catalyst::convertToDestinationPassingStyle(func::FuncOp callee, OpBuilder &
if (callee.isDeclaration()) {
// If the function does not have a body, we are done after modifying the function type.
callee.setFunctionType(dpsFunctionType);
- return;
+ return success();
}
// Insert the new output arguments to the function.
@@ -60,7 +60,9 @@ void catalyst::convertToDestinationPassingStyle(func::FuncOp callee, OpBuilder &
// insertArguments modifies the function type, so we need to update the function type *after*
// inserting the arguments.
- callee.insertArguments(argIndices, memRefReturnTypes, argAttrs, argLocs);
+ if (failed(callee.insertArguments(argIndices, memRefReturnTypes, argAttrs, argLocs))) {
+ return failure();
+ }
callee.setFunctionType(dpsFunctionType);
// Update return sites to copy over the memref that would have been returned to the output.
@@ -83,4 +85,6 @@ void catalyst::convertToDestinationPassingStyle(func::FuncOp callee, OpBuilder &
}
returnOp.getOperandsMutable().assign(nonMemRefReturns);
});
+
+ return success();
}
diff --git a/mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp b/mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp
index 5768982de5..4d1164c140 100644
--- a/mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp
+++ b/mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp
@@ -44,7 +44,7 @@ Value buildTensorLinalgGeneric(OpBuilder &builder, Location loc, ValueRange oper
// Initialize the result tensor
FloatType elementType = cast(resultType.getElementType());
Value zero = builder.create(
- loc, APFloat::getZero(elementType.getFloatSemantics()), elementType);
+ loc, elementType, APFloat::getZero(elementType.getFloatSemantics()));
Value result =
builder.create(loc, resultType.getShape(), resultType.getElementType());
result = builder.create(loc, zero, result).getResult(0);
diff --git a/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp
index 4a0312478e..9498f306e8 100644
--- a/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -58,15 +58,16 @@ struct QubitUnitaryOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const bufferization::BufferizationOptions &options) const
+ const bufferization::BufferizationOptions &options,
+ bufferization::BufferizationState &state) const
{
auto qubitUnitaryOp = cast(op);
Location loc = op->getLoc();
auto tensorType = cast(qubitUnitaryOp.getMatrix().getType());
MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
- auto toMemrefOp =
- rewriter.create(loc, memrefType, qubitUnitaryOp.getMatrix());
- auto memref = toMemrefOp.getResult();
+ auto toBufferOp =
+ rewriter.create(loc, memrefType, qubitUnitaryOp.getMatrix());
+ auto memref = toBufferOp.getResult();
bufferization::replaceOpWithNewBufferizedOp(
rewriter, op, qubitUnitaryOp.getOutQubits().getTypes(),
qubitUnitaryOp.getOutCtrlQubits().getTypes(), memref, qubitUnitaryOp.getInQubits(),
@@ -101,15 +102,16 @@ struct HermitianOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const bufferization::BufferizationOptions &options) const
+ const bufferization::BufferizationOptions &options,
+ bufferization::BufferizationState &state) const
{
auto hermitianOp = cast(op);
Location loc = op->getLoc();
auto tensorType = cast(hermitianOp.getMatrix().getType());
MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
- auto toMemrefOp =
- rewriter.create(loc, memrefType, hermitianOp.getMatrix());
- auto memref = toMemrefOp.getResult();
+ auto toBufferOp =
+ rewriter.create(loc, memrefType, hermitianOp.getMatrix());
+ auto memref = toBufferOp.getResult();
auto newHermitianOp = rewriter.create(loc, hermitianOp.getType(), memref,
hermitianOp.getQubits());
bufferization::replaceOpWithBufferizedValues(rewriter, op, newHermitianOp.getObs());
@@ -143,15 +145,16 @@ struct HamiltonianOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const bufferization::BufferizationOptions &options) const
+ const bufferization::BufferizationOptions &options,
+ bufferization::BufferizationState &state) const
{
auto hamiltonianOp = cast(op);
Location loc = op->getLoc();
auto tensorType = cast(hamiltonianOp.getCoeffs().getType());
MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
- auto toMemrefOp =
- rewriter.create(loc, memrefType, hamiltonianOp.getCoeffs());
- auto memref = toMemrefOp.getResult();
+ auto toBufferOp =
+ rewriter.create(loc, memrefType, hamiltonianOp.getCoeffs());
+ auto memref = toBufferOp.getResult();
auto newHamiltonianOp = rewriter.create(loc, hamiltonianOp.getType(), memref,
hamiltonianOp.getTerms());
bufferization::replaceOpWithBufferizedValues(rewriter, op, newHamiltonianOp.getObs());
@@ -187,7 +190,8 @@ struct SampleOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const bufferization::BufferizationOptions &options) const
+ const bufferization::BufferizationOptions &options,
+ bufferization::BufferizationState &state) const
{
auto sampleOp = cast(op);
Location loc = op->getLoc();
@@ -237,7 +241,8 @@ struct CountsOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const bufferization::BufferizationOptions &options) const
+ const bufferization::BufferizationOptions &options,
+ bufferization::BufferizationState &state) const
{
auto countsOp = cast(op);
Location loc = op->getLoc();
@@ -297,7 +302,8 @@ struct ProbsOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const bufferization::BufferizationOptions &options) const
+ const bufferization::BufferizationOptions &options,
+ bufferization::BufferizationState &state) const
{
auto probsOp = cast(op);
Location loc = op->getLoc();
@@ -350,7 +356,8 @@ struct StateOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const bufferization::BufferizationOptions &options) const
+ const bufferization::BufferizationOptions &options,
+ bufferization::BufferizationState &state) const
{
auto stateOp = cast(op);
Location loc = op->getLoc();
@@ -401,16 +408,17 @@ struct SetStateOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const bufferization::BufferizationOptions &options) const
+ const bufferization::BufferizationOptions &options,
+ bufferization::BufferizationState &state) const
{
auto setStateOp = cast(op);
Location loc = op->getLoc();
auto tensorType = cast(setStateOp.getInState().getType());
MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
- auto toMemrefOp =
- rewriter.create(loc, memrefType, setStateOp.getInState());
- auto memref = toMemrefOp.getResult();
+ auto toBufferOp =
+ rewriter.create(loc, memrefType, setStateOp.getInState());
+ auto memref = toBufferOp.getResult();
auto newSetStateOp = rewriter.create(loc, setStateOp.getOutQubits().getTypes(),
memref, setStateOp.getInQubits());
bufferization::replaceOpWithBufferizedValues(rewriter, op, newSetStateOp.getOutQubits());
@@ -443,16 +451,17 @@ struct SetBasisStateOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const bufferization::BufferizationOptions &options) const
+ const bufferization::BufferizationOptions &options,
+ bufferization::BufferizationState &state) const
{
auto setBasisStateOp = cast(op);
Location loc = op->getLoc();
auto tensorType = cast(setBasisStateOp.getBasisState().getType());
MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
- auto toMemrefOp = rewriter.create(
+ auto toBufferOp = rewriter.create(
loc, memrefType, setBasisStateOp.getBasisState());
- auto memref = toMemrefOp.getResult();
+ auto memref = toBufferOp.getResult();
auto newSetStateOp = rewriter.create(
loc, setBasisStateOp.getOutQubits().getTypes(), memref, setBasisStateOp.getInQubits());
bufferization::replaceOpWithBufferizedValues(rewriter, op, newSetStateOp.getOutQubits());
diff --git a/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp b/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp
index eaf30e2829..8dbf401c46 100644
--- a/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp
+++ b/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp
@@ -44,7 +44,8 @@ Value getGlobalString(Location loc, OpBuilder &rewriter, StringRef key, StringRe
}
return rewriter.create(loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
type, rewriter.create(loc, glb),
- ArrayRef{0, 0}, true);
+ ArrayRef{0, 0},
+ LLVM::GEPNoWrapFlags::inbounds);
}
/**
@@ -80,13 +81,17 @@ Value getModifiersPtr(Location loc, RewriterBase &rewriter, const TypeConverter
auto structType = LLVM::LLVMStructType::getLiteral(ctx, {boolType, sizeType, ptrType, ptrType});
auto modifiersPtr = catalyst::getStaticAlloca(loc, rewriter, structType, 1).getResult();
auto adjointPtr = rewriter.create(loc, ptrType, structType, modifiersPtr,
- llvm::ArrayRef{0, 0}, true);
+ llvm::ArrayRef{0, 0},
+ LLVM::GEPNoWrapFlags::inbounds);
auto numControlledPtr = rewriter.create(loc, ptrType, structType, modifiersPtr,
- llvm::ArrayRef{0, 1}, true);
- auto controlledWiresPtr = rewriter.create(
- loc, ptrType, structType, modifiersPtr, llvm::ArrayRef{0, 2}, true);
- auto controlledValuesPtr = rewriter.create(
- loc, ptrType, structType, modifiersPtr, llvm::ArrayRef{0, 3}, true);
+ llvm::ArrayRef{0, 1},
+ LLVM::GEPNoWrapFlags::inbounds);
+ auto controlledWiresPtr = rewriter.create(loc, ptrType, structType, modifiersPtr,
+ llvm::ArrayRef{0, 2},
+ LLVM::GEPNoWrapFlags::inbounds);
+ auto controlledValuesPtr = rewriter.create(loc, ptrType, structType, modifiersPtr,
+ llvm::ArrayRef{0, 3},
+ LLVM::GEPNoWrapFlags::inbounds);
Value ctrlPtr = nullPtr;
Value valuePtr = nullPtr;
@@ -98,13 +103,15 @@ Value getModifiersPtr(Location loc, RewriterBase &rewriter, const TypeConverter
for (int i = 0; static_cast(i) < controlledQubits.size(); i++) {
{
auto itemPtr = rewriter.create(loc, ptrType, ptrType, ctrlPtr,
- llvm::ArrayRef{i}, true);
+ llvm::ArrayRef{i},
+ LLVM::GEPNoWrapFlags::inbounds);
auto qubit = controlledQubits[i];
rewriter.create(loc, qubit, itemPtr);
}
{
auto itemPtr = rewriter.create(loc, ptrType, boolType, valuePtr,
- llvm::ArrayRef{i}, true);
+ llvm::ArrayRef{i},
+ LLVM::GEPNoWrapFlags::inbounds);
auto value = controlledValues[i];
rewriter.create(loc, value, itemPtr);
}
@@ -1012,7 +1019,7 @@ struct SetStateOpPattern : public OpConversionPattern {
auto voidTy = LLVM::LLVMVoidType::get(ctx);
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
ModuleOp moduleOp = op->getParentOfType();
- auto func = mlir::LLVM::lookupOrCreateFn(moduleOp, "__catalyst__qis__SetState",
+ auto func = mlir::LLVM::lookupOrCreateFn(rewriter, moduleOp, "__catalyst__qis__SetState",
{ptrTy, i64}, voidTy, isVarArg)
.value();
@@ -1052,9 +1059,10 @@ struct SetBasisStateOpPattern : public OpConversionPattern {
auto voidTy = LLVM::LLVMVoidType::get(ctx);
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
ModuleOp moduleOp = op->getParentOfType();
- auto func = mlir::LLVM::lookupOrCreateFn(moduleOp, "__catalyst__qis__SetBasisState",
- {ptrTy, i64}, voidTy, isVarArg)
- .value();
+ auto func =
+ mlir::LLVM::lookupOrCreateFn(rewriter, moduleOp, "__catalyst__qis__SetBasisState",
+ {ptrTy, i64}, voidTy, isVarArg)
+ .value();
SmallVector args;
diff --git a/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp b/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp
index 4c91fbf67c..f4b8a2f9ca 100644
--- a/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp
+++ b/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp
@@ -214,13 +214,9 @@ struct EmitCatalystPyInterfacePass
patterns.add(context);
GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingOps;
- config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
- config.maxIterations = 1;
- // TODO: Update to the following lines the next time we update llvm
- // config.setStrictness(GreedyRewriteStrictness::ExistingOps);
- // config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);
- // config.setMaxIterations(1);
+ config.setStrictness(GreedyRewriteStrictness::ExistingOps);
+ config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);
+ config.setMaxIterations(1);
auto op = getOperation();
SmallVector targets;
diff --git a/mlir/llvm-project b/mlir/llvm-project
index 179d30f8c3..f8cb7987c6 160000
--- a/mlir/llvm-project
+++ b/mlir/llvm-project
@@ -1 +1 @@
-Subproject commit 179d30f8c3fddd3c85056fd2b8e877a4a8513158
+Subproject commit f8cb7987c64dcffb72414a40560055cb717dbf74
diff --git a/mlir/mlir-hlo b/mlir/mlir-hlo
index 617a9361d1..1dd2e71331 160000
--- a/mlir/mlir-hlo
+++ b/mlir/mlir-hlo
@@ -1 +1 @@
-Subproject commit 617a9361d186199480c080c9e8c474a5e30c22d1
+Subproject commit 1dd2e71331014ae0373f6bf900ce6be393357190
diff --git a/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch b/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
index 9f80c60a75..8746e54c73 100644
--- a/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
+++ b/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
@@ -1,8 +1,8 @@
diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
-index 85050315..414318eb 100644
+index 7c234dd4..846f68b4 100644
--- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
+++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
-@@ -3940,14 +3940,6 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) {
+@@ -3942,14 +3942,6 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) {
case Intrinsic::nearbyint:
case Intrinsic::round:
case Intrinsic::sqrt:
diff --git a/mlir/patches/llvm-bufferization-segfault.patch b/mlir/patches/llvm-bufferization-segfault.patch
new file mode 100644
index 0000000000..e894516820
--- /dev/null
+++ b/mlir/patches/llvm-bufferization-segfault.patch
@@ -0,0 +1,27 @@
+diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+index 453ed43bcad..dff994729a4 100644
+--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
++++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+@@ -89,16 +89,12 @@ static FuncOp getCalledFunction(CallOpInterface callOp,
+ /// Return the FuncOp called by `callOp`.
+ static FuncOp getCalledFunction(CallOpInterface callOp,
+ const AnalysisState &state) {
+- auto &oneShotAnalysisState = static_cast(state);
+-
+- if (auto *funcAnalysisState =
+- oneShotAnalysisState.getExtension()) {
+- // Use the cached symbol tables.
+- return getCalledFunction(callOp, funcAnalysisState->symbolTables);
+- }
+-
+- SymbolTableCollection symbolTables;
+- return getCalledFunction(callOp, symbolTables);
++ SymbolRefAttr sym =
++ llvm::dyn_cast_if_present(callOp.getCallableForCallee());
++ if (!sym)
++ return nullptr;
++ return dyn_cast_or_null(
++ SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+ }
+
+ /// Get FuncAnalysisState.
diff --git a/mlir/patches/mhlo-add-back-necessary-passes.patch b/mlir/patches/mhlo-add-back-necessary-passes.patch
index b56ede8dd5..c430adae43 100644
--- a/mlir/patches/mhlo-add-back-necessary-passes.patch
+++ b/mlir/patches/mhlo-add-back-necessary-passes.patch
@@ -7,12 +7,12 @@ Subject: [PATCH] restore the removed mhlo passes we need:
---
mhlo/transforms/CMakeLists.txt | 6 +
.../legalize_control_flow.cc | 288 +++++++++
- .../transforms/legalize_sort/legalize_sort.cc | 577 ++++++++++++++++++
+ .../transforms/legalize_sort/legalize_sort.cc | 578 ++++++++++++++++++
.../legalize_to_standard.cc | 243 ++++++++
.../legalize_to_standard_patterns.td | 92 +++
mhlo/transforms/mhlo_passes.td | 19 +
mhlo/transforms/passes.h | 4 +
- 7 files changed, 1229 insertions(+)
+ 7 files changed, 1230 insertions(+)
create mode 100644 mhlo/transforms/legalize_control_flow/legalize_control_flow.cc
create mode 100644 mhlo/transforms/legalize_sort/legalize_sort.cc
create mode 100644 mhlo/transforms/legalize_to_standard/legalize_to_standard.cc
@@ -342,7 +342,7 @@ new file mode 100644
index 00000000..8ba9de9a
--- /dev/null
+++ b/mhlo/transforms/legalize_sort/legalize_sort.cc
-@@ -0,0 +1,577 @@
+@@ -0,0 +1,578 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
@@ -883,8 +883,9 @@ index 00000000..8ba9de9a
+
+ SmallVector outputTensors;
+ for (auto [out0, out1] : llvm::zip(outputMemrefs, scratchMemrefs)) {
++ Value s = b.create(parity, out1, out0).getResult();
+ outputTensors.push_back(b.create(
-+ b.create(parity, out1, out0), /*restrict=*/true));
++ memref::getTensorTypeFromMemRefType(s.getType()), s, /*restrict=*/true));
+ }
+
+ rewriter.replaceOp(op, outputTensors);
diff --git a/mlir/patches/mhlo-remove-shardy.patch b/mlir/patches/mhlo-remove-shardy.patch
index f78200bdab..32ce71061f 100644
--- a/mlir/patches/mhlo-remove-shardy.patch
+++ b/mlir/patches/mhlo-remove-shardy.patch
@@ -84,7 +84,7 @@ index cabd6a9f..2e64b4ed 100644
patterns->add(context);
patterns->add(context);
patterns->add(context);
-- populateSdyShapeRefinementPatterns(patterns, context);
+- populateSdyShapeRefinementPatterns(context, patterns);
};
if (failed(stablehlo::refineEntryFunction(*context, func,
@@ -92,7 +92,7 @@ index cabd6a9f..2e64b4ed 100644
patterns->add(context);
patterns->add(context);
patterns->add(context);
-- populateSdyShapeRefinementPatterns(patterns, context);
+- populateSdyShapeRefinementPatterns(context, patterns);
}
} // namespace stablehlo_ext
diff --git a/mlir/patches/mhlo-rename-sort.patch b/mlir/patches/mhlo-rename-sort.patch
new file mode 100644
index 0000000000..c356cc35e3
--- /dev/null
+++ b/mlir/patches/mhlo-rename-sort.patch
@@ -0,0 +1,15 @@
+diff --git a/utils/cycle_detector.cc b/utils/cycle_detector.cc
+index e3901ae88..890f39654 100644
+--- a/utils/cycle_detector.cc
++++ b/utils/cycle_detector.cc
+@@ -199,8 +199,8 @@ static void backwardDfs(GraphCycles::Rep* r, int32_t n, int32_t lowerBound) {
+ // Recomputes rank assignments to make them compatible with the edges (producer
+ // has smaller rank than its consumer)
+ static void reorder(GraphCycles::Rep* r) {
+- sort(r->nodes, &r->deltab);
+- sort(r->nodes, &r->deltaf);
++ mlir::sort(r->nodes, &r->deltab);
++ mlir::sort(r->nodes, &r->deltaf);
+
+ // Adds contents of delta lists to list (backwards deltas first).
+ r->list.clear();
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 6394246da0..e3a8b603eb 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -3,6 +3,8 @@ include(AddMLIRPython)
# TODO: Add an upstream cmake param for this vs having a global here.
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=mlir_quantum.")
+# Ignore nanobind warnings
+add_compile_options(-w)
################################################################################
# Declare Dialect Sources
diff --git a/mlir/test/Catalyst/BufferizationTest.mlir b/mlir/test/Catalyst/BufferizationTest.mlir
index effc229a64..67f186943d 100644
--- a/mlir/test/Catalyst/BufferizationTest.mlir
+++ b/mlir/test/Catalyst/BufferizationTest.mlir
@@ -23,7 +23,7 @@
func.func @dbprint_val(%arg0: tensor) {
- // CHECK: %0 = bufferization.to_memref %arg0
+ // CHECK: %0 = bufferization.to_buffer %arg0
// CHECK: "catalyst.print"(%0) : (memref) -> ()
"catalyst.print"(%arg0) : (tensor) -> ()
@@ -34,7 +34,7 @@ func.func @dbprint_val(%arg0: tensor) {
func.func @dbprint_memref(%arg0: tensor) {
- // CHECK: %0 = bufferization.to_memref %arg0
+ // CHECK: %0 = bufferization.to_buffer %arg0
// CHECK: "catalyst.print"(%0) <{print_descriptor}> : (memref) -> ()
"catalyst.print"(%arg0) {print_descriptor} : (tensor) -> ()
@@ -54,7 +54,7 @@ func.func @dbprint_str() {
// -----
func.func @custom_call(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> {
- // CHECK: [[sourceAlloc:%.+]] = bufferization.to_memref %arg0
+ // CHECK: [[sourceAlloc:%.+]] = bufferization.to_buffer %arg0
// CHECK: [[destAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64>
// CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[sourceAlloc]], [[destAlloc]]) {number_original_arg = array} :
// CHECK-SAME: (memref<3x3xf64>, memref<3x3xf64>) -> ()
@@ -72,7 +72,7 @@ func.func @custom_call_copy(%arg0: tensor<2x3xf64>) -> tensor<2x2xf64> {
// COM: e.g. coming from tensor subviews
// COM: a copy needs to be performed because the kernels only allow for contiguous arrays as inputs
//
- // CHECK: [[sourceAlloc:%.+]] = bufferization.to_memref %arg0
+ // CHECK: [[sourceAlloc:%.+]] = bufferization.to_buffer %arg0
// CHECK: [[subview:%.+]] = memref.subview [[sourceAlloc]]
// CHECK-SAME: memref<2x3xf64> to memref<2x2xf64, strided<[3, 1]>>
// CHECK: [[copyAlloc:%.+]] = memref.alloc() : memref<2x2xf64>
@@ -106,7 +106,7 @@ module @test1 {
// CHECK-LABEL: @foo(
// CHECK-SAME: [[arg0:%.+]]: tensor)
func.func private @foo(%arg0: tensor) -> tensor {
- // CHECK-DAG: [[memref0:%.+]] = bufferization.to_memref [[arg0]] : tensor to memref
+ // CHECK-DAG: [[memref0:%.+]] = bufferization.to_buffer [[arg0]] : tensor to memref
// CHECK-DAG: [[resAlloc:%.+]] = memref.alloc() {{.*}}: memref
// CHECK: catalyst.callback_call @callback_1([[memref0]], [[resAlloc]]) : (memref, memref) -> ()
%1 = catalyst.callback_call @callback_1(%arg0) : (tensor) -> (tensor)
diff --git a/mlir/test/Catalyst/ConversionTest.mlir b/mlir/test/Catalyst/ConversionTest.mlir
index 20ce7c6bc7..9a9fc17102 100644
--- a/mlir/test/Catalyst/ConversionTest.mlir
+++ b/mlir/test/Catalyst/ConversionTest.mlir
@@ -146,13 +146,13 @@ module @test1 {
// CHECK-SAME: [[arg0:%.+]]: tensor
// CHECK-SAME:)
func.func private @foo(%arg0: tensor) -> tensor {
- // CHECK: [[memref0:%.+]] = bufferization.to_memref [[arg0]]
+ // CHECK: [[memref0:%.+]] = bufferization.to_buffer [[arg0]]
// CHECK: [[ptr0:%.+]] = llvm.alloca {{.*}}
// CHECK: [[ptr1:%.+]] = llvm.alloca {{.*}}
// CHECK: [[struct0:%.+]] = builtin.unrealized_conversion_cast [[memref0]]
// CHECK: [[tensor1:%.+]] = bufferization.alloc_tensor()
- // CHECK: [[memref1:%.+]] = bufferization.to_memref [[tensor1]]
+ // CHECK: [[memref1:%.+]] = bufferization.to_buffer [[tensor1]]
// CHECK: [[struct1:%.+]] = builtin.unrealized_conversion_cast [[memref1]]
// CHECK: llvm.store [[struct0]], [[ptr1]]
@@ -160,9 +160,9 @@ module @test1 {
// call @callback_1([[ptr0]], [[ptr1]])
- %0 = bufferization.to_memref %arg0 : tensor to memref
+ %0 = bufferization.to_buffer %arg0 : tensor to memref
%1 = bufferization.alloc_tensor() {memory_space = 0 : i64} : tensor
- %2 = bufferization.to_memref %1 : tensor to memref
+ %2 = bufferization.to_buffer %1 : tensor to memref
catalyst.callback_call @callback_1(%0, %2) : (memref, memref) -> ()
diff --git a/mlir/test/Gradient/BufferizationTest.mlir b/mlir/test/Gradient/BufferizationTest.mlir
index 4a8f9a246e..8e84995888 100644
--- a/mlir/test/Gradient/BufferizationTest.mlir
+++ b/mlir/test/Gradient/BufferizationTest.mlir
@@ -63,7 +63,7 @@ func.func private @circuit(%arg0: tensor<2xf64>)
// CHECK-LABEL: @adjoint_with_tensor_arg
func.func @adjoint_with_tensor_arg(%arg0: tensor<2xf64>, %arg1: index) {
- // CHECK: [[argBuffer:%.+]] = bufferization.to_memref %arg0 : tensor<2xf64> to memref<2xf64>
+ // CHECK: [[argBuffer:%.+]] = bufferization.to_buffer %arg0 : tensor<2xf64> to memref<2xf64>
// CHECK: [[alloc:%.+]] = memref.alloc(%arg1) : memref
// CHECK: gradient.adjoint @circuit([[argBuffer]]) size(%arg1) in([[alloc]] : memref) : (memref<2xf64>) -> ()
%grad = gradient.adjoint @circuit(%arg0) size(%arg1) : (tensor<2xf64>) -> tensor
@@ -77,7 +77,7 @@ func.func private @circuit(%arg0: tensor<2xf64>)
// CHECK-LABEL: @adjoint_with_multiple_results
func.func @adjoint_with_multiple_results(%arg0: tensor<2xf64>, %arg1: index) {
- // CHECK: [[argBuffer:%.+]] = bufferization.to_memref %arg0 : tensor<2xf64> to memref<2xf64>
+ // CHECK: [[argBuffer:%.+]] = bufferization.to_buffer %arg0 : tensor<2xf64> to memref<2xf64>
// CHECK: [[alloc0:%.+]] = memref.alloc(%arg1) : memref
// CHECK: [[alloc1:%.+]] = memref.alloc(%arg1) : memref
// CHECK: gradient.adjoint @circuit([[argBuffer]]) size(%arg1) in([[alloc0]], [[alloc1]]
@@ -93,7 +93,7 @@ func.func private @circuit(%arg0: f64)
// CHECK-LABEL: @backprop_scalar_in
func.func @backprop_scalar_in(%arg0: f64, %arg1: tensor) {
- // CHECK: [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : tensor to memref
+ // CHECK: [[cotangentSource:%.+]] = bufferization.to_buffer %arg1 : tensor to memref
// CHECK: [[dim1:%.+]] = memref.dim [[cotangentSource]]
// CHECK: [[cotangentRes:%.+]] = memref.alloc([[dim1]]) {alignment = 64 : i64} : memref
// CHECK: memref.copy [[cotangentSource]], [[cotangentRes]]
@@ -115,8 +115,8 @@ func.func private @circuit(%arg0: tensor)
// CHECK-LABEL: @backprop_tensor_in
func.func @backprop_tensor_in(%arg0: tensor, %arg1: tensor) {
- // CHECK-DAG: [[argSource:%.+]] = bufferization.to_memref %arg0 : tensor to memref
- // CHECK-DAG: [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : tensor to memref
+ // CHECK-DAG: [[argSource:%.+]] = bufferization.to_buffer %arg0 : tensor to memref
+ // CHECK-DAG: [[cotangentSource:%.+]] = bufferization.to_buffer %arg1 : tensor to memref
// CHECK: [[dim2:%.+]] = memref.dim [[cotangentSource]]
// CHECK: [[cotangentRes:%.+]] = memref.alloc([[dim2]]) {alignment = 64 : i64} : memref
// CHECK: memref.copy [[cotangentSource]], [[cotangentRes]]
@@ -141,8 +141,8 @@ func.func private @circuit(%arg0: tensor<10xf64>, %arg1: tensor<2xf64>)
// CHECK-LABEL: @backprop_multiple_tensors_in
func.func @backprop_multiple_tensors_in(%arg0: tensor<10xf64>, %arg1: tensor<2xf64>, %arg2: tensor) {
- // CHECK-DAG: [[argSource0:%.+]] = bufferization.to_memref %arg0 : tensor<10xf64> to memref<10xf64>
- // CHECK-DAG: [[argSource1:%.+]] = bufferization.to_memref %arg1 : tensor<2xf64> to memref<2xf64>
+ // CHECK-DAG: [[argSource0:%.+]] = bufferization.to_buffer %arg0 : tensor<10xf64> to memref<10xf64>
+ // CHECK-DAG: [[argSource1:%.+]] = bufferization.to_buffer %arg1 : tensor<2xf64> to memref<2xf64>
// CHECK: memref.alloc
// CHECK: memref.copy
// CHECK: [[argShadow1:%.+]] = memref.alloc() : memref<10xf64>
@@ -171,8 +171,8 @@ gradient.forward @callback_fn_fwd.fwd(%arg0: tensor<2xf64>) -> (tensor, ten
// CHECK: [[in:%.+]] = bufferization.to_tensor %arg0 : memref<2xf64>
// CHECK: [[callOut:%.+]]:2 = func.call @callback_fn_fwd([[in]]) : (tensor<2xf64>) -> (tensor, tensor<2xf64>)
- // CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : tensor to memref
- // CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : tensor<2xf64> to memref<2xf64>
+ // CHECK: [[res0:%.+]] = bufferization.to_buffer [[callOut]]#0 : tensor to memref
+ // CHECK: [[res1:%.+]] = bufferization.to_buffer [[callOut]]#1 : tensor<2xf64> to memref<2xf64>
// CHECK: gradient.return {empty = false} [[res0]], [[res1]] : memref, memref<2xf64>
%0:2 = func.call @callback_fn_fwd(%arg0) : (tensor<2xf64>) -> (tensor, tensor<2xf64>)
@@ -192,7 +192,7 @@ gradient.reverse @callback_fn_vjp.rev(%arg0: tensor, %arg1: tensor<2xf64>)
// CHECK: [[in1:%.+]] = bufferization.to_tensor %arg1 : memref<2xf64>
// CHECK: [[in0:%.+]] = bufferization.to_tensor %arg0 : memref
// CHECK: [[callOut:%.+]] = func.call @callback_fn_vjp([[in1]], [[in0]]) : (tensor<2xf64>, tensor) -> tensor<2xf64>
- // CHECK: [[res:%.+]] = bufferization.to_memref [[callOut]] : tensor<2xf64> to memref<2xf64>
+ // CHECK: [[res:%.+]] = bufferization.to_buffer [[callOut]] : tensor<2xf64> to memref<2xf64>
// CHECK: gradient.return {empty = true} [[res]] : memref<2xf64>
%0 = func.call @callback_fn_vjp(%arg1, %arg0) : (tensor<2xf64>, tensor) -> tensor<2xf64>
diff --git a/mlir/test/Gradient/FiniteDifferenceTest.mlir b/mlir/test/Gradient/FiniteDifferenceTest.mlir
index 13af9e7956..37d70471d6 100644
--- a/mlir/test/Gradient/FiniteDifferenceTest.mlir
+++ b/mlir/test/Gradient/FiniteDifferenceTest.mlir
@@ -161,7 +161,7 @@ func.func private @funcMultiArg(%arg0: tensor<7xf64>, %arg1: f64) -> tensor<2xf6
// CHECK: [[BASE:%.+]] = call @funcMultiArg(%arg0, %arg1)
// CHECK: [[DIFF:%.+]] = tensor.generate
// CHECK-NEXT: ^bb0(%arg2: index, %arg3: index):
- // CHECK: [[MEMREF:%.+]] = bufferization.to_memref %arg0
+ // CHECK: [[MEMREF:%.+]] = bufferization.to_buffer %arg0
// CHECK: [[COPY:%.+]] = bufferization.clone [[MEMREF]]
// CHECK: [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]]
// CHECK: [[VAL:%.+]] = tensor.extract [[TENSOR]][%arg3]
@@ -188,7 +188,7 @@ func.func private @funcMultiArg(%arg0: tensor<7xf64>, %arg1: f64) -> tensor<2xf6
// CHECK: [[BASE:%.+]] = call @funcMultiArg(%arg0, %arg1)
// CHECK: [[DIFF:%.+]] = tensor.generate
// CHECK-NEXT: ^bb0(%arg2: index, %arg3: index):
- // CHECK: [[MEMREF:%.+]] = bufferization.to_memref %arg0
+ // CHECK: [[MEMREF:%.+]] = bufferization.to_buffer %arg0
// CHECK: [[COPY:%.+]] = bufferization.clone [[MEMREF]]
// CHECK: [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]]
// CHECK: [[VAL:%.+]] = tensor.extract [[TENSOR]][%arg3]
@@ -227,7 +227,7 @@ func.func private @funcMultiRes(%arg0: tensor<7xf64>) -> (f64, tensor<2xf64>) at
// CHECK: [[BASE:%.+]]:2 = call @funcMultiRes(%arg0)
// CHECK: [[DIFF:%.+]] = tensor.generate
// CHECK-NEXT: ^bb0(%arg1: index):
- // CHECK: [[MEMREF:%.+]] = bufferization.to_memref %arg0
+ // CHECK: [[MEMREF:%.+]] = bufferization.to_buffer %arg0
// CHECK: [[COPY:%.+]] = bufferization.clone [[MEMREF]]
// CHECK: [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]]
// CHECK: [[VAL:%.+]] = tensor.extract [[TENSOR]][%arg1]
@@ -239,7 +239,7 @@ func.func private @funcMultiRes(%arg0: tensor<7xf64>) -> (f64, tensor<2xf64>) at
// CHECK: [[R0:%.+]] = arith.divf [[DIFF]]
// CHECK: [[DIFF:%.+]] = tensor.generate
// CHECK-NEXT: ^bb0(%arg1: index, %arg2: index):
- // CHECK: [[MEMREF:%.+]] = bufferization.to_memref %arg0
+ // CHECK: [[MEMREF:%.+]] = bufferization.to_buffer %arg0
// CHECK: [[COPY:%.+]] = bufferization.clone [[MEMREF]]
// CHECK: [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]]
// CHECK: [[VAL:%.+]] = tensor.extract [[TENSOR]][%arg2]
@@ -279,7 +279,7 @@ func.func private @funcDynamicTensor(%arg0: tensor) -> tensor<2x?xf64>
// CHECK: [[DIFF:%.+]] = tensor.generate [[DDIM0]], [[DDIM1]]
// CHECK-NEXT: ^bb0([[i0:%.+]]: index, [[i1:%.+]]: index, [[i2:%.+]]: index, [[i3:%.+]]: index):
- // CHECK: [[MEMREF:%.+]] = bufferization.to_memref %arg0
+ // CHECK: [[MEMREF:%.+]] = bufferization.to_buffer %arg0
// CHECK: [[COPY:%.+]] = bufferization.clone [[MEMREF]]
// CHECK: [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]]
// CHECK: [[VAL:%.+]] = tensor.extract [[TENSOR]][[[i2]], [[i3]]]
diff --git a/mlir/test/Gradient/PostProcessingTest.mlir b/mlir/test/Gradient/PostProcessingTest.mlir
index 2403372410..9ae25800f1 100644
--- a/mlir/test/Gradient/PostProcessingTest.mlir
+++ b/mlir/test/Gradient/PostProcessingTest.mlir
@@ -25,15 +25,15 @@ gradient.forward @callback_fn_fwd.fwd(%arg0: memref<2xf64>) -> (memref, mem
// CHECK: [[in:%.+]] = bufferization.to_tensor %arg0 : memref<2xf64> to tensor<2xf64>
// CHECK: [[callOut:%.+]]:2 = func.call @callback_fn_fwd([[in]]) : (tensor<2xf64>) -> (tensor, tensor<2xf64>)
- // CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : tensor to memref
- // CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : tensor<2xf64> to memref<2xf64>
+ // CHECK: [[res0:%.+]] = bufferization.to_buffer [[callOut]]#0 : tensor to memref
+ // CHECK: [[res1:%.+]] = bufferization.to_buffer [[callOut]]#1 : tensor<2xf64> to memref<2xf64>
// CHECK: memref.copy [[res0]], %arg2 : memref to memref
// CHECK: gradient.return {empty = false} [[res1]] : memref<2xf64>
%0 = bufferization.to_tensor %arg0 : memref<2xf64> to tensor<2xf64>
%1:2 = func.call @callback_fn_fwd(%0) : (tensor<2xf64>) -> (tensor, tensor<2xf64>)
- %2 = bufferization.to_memref %1#0 : tensor to memref
- %3 = bufferization.to_memref %1#1 : tensor<2xf64> to memref<2xf64>
+ %2 = bufferization.to_buffer %1#0 : tensor