Skip to content

Commit 29f1955

Browse files
committed
[FRONTEND] Add get_int_attr to Operation for out-of-tree walk
Enables extracting integer attributes from MLIR operations during out-of-tree walks. Useful for extracting constant values and attributes like start/end from tt.make_range. Follows the pattern established in PR triton-lang#3191. Signed-off-by: Joshua James Venter <venter.joshua@gmail.com>
1 parent 024d809 commit 29f1955

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

python/src/ir.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,13 @@ void init_triton_ir(py::module &&m) {
635635
return py::none();
636636
return py::str(ret.getValue().str());
637637
})
638+
.def("get_int_attr",
639+
[](Operation &self, const std::string &name) -> py::object {
640+
auto ret = self.getAttrOfType<IntegerAttr>(name);
641+
if (!ret)
642+
return py::none();
643+
return py::int_(ret.getInt());
644+
})
638645
.def("get_bool_attr",
639646
[](Operation &self, const std::string &name) -> py::object {
640647
auto ret = self.getAttrOfType<BoolAttr>(name);

python/test/unit/runtime/test_bindings.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import torch
55
import math
66

7+
_BLOCK_SIZE = 16
8+
79

810
@triton.jit
911
def add_helper(x, y):
@@ -50,14 +52,20 @@ def walk_fn(op):
5052
op.get_str_attr("sym_name")
5153
if name == "tt.call":
5254
op.get_flat_symbol_ref_attr("callee")
55+
if name == "tt.make_range":
56+
assert 0 == op.get_int_attr("start")
57+
assert _BLOCK_SIZE == op.get_int_attr("end")
58+
if name == "arith.constant":
59+
val = op.get_int_attr("value")
60+
assert val is None or isinstance(val, int)
5361

5462
kernel = add_kernel
5563
args = [
5664
torch.empty((32, 32), device=device), # in_ptr0
5765
torch.empty((32, 32), device=device), # in_ptr1
5866
1024, # n_elements
5967
torch.empty((32, 32), device=device), # out_ptr
60-
16, # BLOCK_SIZE
68+
_BLOCK_SIZE, # BLOCK_SIZE
6169
]
6270
target = triton.runtime.driver.active.get_current_target()
6371
backend = triton.compiler.compiler.make_backend(target)

0 commit comments

Comments
 (0)