From 0ff7a35cb35a6799376564c1f73e30252cb75df4 Mon Sep 17 00:00:00 2001 From: Joshua James Venter Date: Wed, 3 Dec 2025 17:16:56 +0200 Subject: [PATCH] [FRONTEND] Add get_int_attr to Operation for out-of-tree walk Enables extracting integer attributes from MLIR operations during out-of-tree walks. Useful for extracting constant values as well as attributes like start and end from tt.make_range, for example. Follows the pattern established in PR #3191. Signed-off-by: Joshua James Venter --- python/src/ir.cc | 7 +++++++ python/test/unit/runtime/test_bindings.py | 10 +++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/python/src/ir.cc b/python/src/ir.cc index bf662affa343..d4e0619cca3c 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -635,6 +635,13 @@ void init_triton_ir(py::module &&m) { return py::none(); return py::str(ret.getValue().str()); }) + .def("get_int_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::int_(ret.getInt()); + }) .def("get_bool_attr", [](Operation &self, const std::string &name) -> py::object { auto ret = self.getAttrOfType(name); diff --git a/python/test/unit/runtime/test_bindings.py b/python/test/unit/runtime/test_bindings.py index 6b28cfe3db10..de9c1dc9c7fb 100644 --- a/python/test/unit/runtime/test_bindings.py +++ b/python/test/unit/runtime/test_bindings.py @@ -4,6 +4,8 @@ import torch import math +_BLOCK_SIZE = 16 + @triton.jit def add_helper(x, y): @@ -50,6 +52,12 @@ def walk_fn(op): op.get_str_attr("sym_name") if name == "tt.call": op.get_flat_symbol_ref_attr("callee") + if name == "tt.make_range": + assert 0 == op.get_int_attr("start") + assert _BLOCK_SIZE == op.get_int_attr("end") + if name == "arith.constant": + val = op.get_int_attr("value") + assert val is None or isinstance(val, int) kernel = add_kernel args = [ @@ -57,7 +65,7 @@ def walk_fn(op): torch.empty((32, 32), device=device), # in_ptr1 1024, # n_elements torch.empty((32, 32), device=device), # out_ptr - 16, # BLOCK_SIZE + _BLOCK_SIZE, # BLOCK_SIZE ] target = triton.runtime.driver.active.get_current_target() backend = triton.compiler.compiler.make_backend(target)