Skip to content

Commit 0036b31

Browse files
authored
[FRONTEND] Expand MLIR bindings for out-of-tree walk (#3191)
Summary: This PR adds a few simple MLIR bindings to `ir.cc` to allow walking the MLIR structure of the TTIR module out of tree. This will help make the Triton kernel analysis performed in PyTorch 2 more robust and reliable (related PR in PT2: pytorch/pytorch#120476).
1 parent df8fd02 commit 0036b31

2 files changed

Lines changed: 134 additions & 9 deletions

File tree

python/src/ir.cc

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "triton/Dialect/Triton/IR/Types.h"
2121
#include "triton/Dialect/Triton/IR/Utility.h"
2222
#include "triton/Tools/Sys/GetEnv.hpp"
23+
#include <pybind11/functional.h>
2324
#include <pybind11/pybind11.h>
2425
#include <pybind11/stl.h>
2526

@@ -249,14 +250,22 @@ void init_triton_ir(py::module &&m) {
249250
[](Value &self, Value &newValue) {
250251
self.replaceAllUsesWith(newValue);
251252
})
252-
.def("get_type", &Value::getType);
253+
.def("get_type", &Value::getType)
254+
.def("id", [](Value &self) {
255+
// The Value is identified by and compared with
256+
// other Values via the underlying ValueImpl
257+
return (uint64_t)self.getImpl();
258+
});
259+
260+
py::class_<OpResult, Value>(m, "op_result", py::module_local());
253261

254262
py::class_<BlockArgument, Value>(m, "block_argument", py::module_local());
255263

256264
py::class_<Region>(m, "region", py::module_local())
257265
.def("get_parent_region", &Region::getParentRegion, ret::reference)
258266
.def("size", [](Region &self) { return self.getBlocks().size(); })
259-
.def("empty", &Region::empty);
267+
.def("empty", &Region::empty)
268+
.def("id", [](Region &self) { return (uint64_t)&self; });
260269

261270
py::class_<Block>(m, "block", py::module_local())
262271
.def("arg",
@@ -271,6 +280,7 @@ void init_triton_ir(py::module &&m) {
271280
self.addArgument(ty, loc);
272281
})
273282
.def("get_num_arguments", &Block::getNumArguments)
283+
.def("get_argument", &Block::getArgument)
274284
.def("dump", &Block::dump)
275285
.def("move_before",
276286
[](Block &self, Block &dst) { self.moveBefore(&dst); })
@@ -318,7 +328,8 @@ void init_triton_ir(py::module &&m) {
318328
return !self.empty() &&
319329
self.back().hasTrait<OpTrait::ReturnLike>();
320330
})
321-
.def("erase", [](Block &self) { self.erase(); });
331+
.def("erase", [](Block &self) { self.erase(); })
332+
.def("id", [](Block &self) { return (uint64_t)&self; });
322333

323334
py::class_<Attribute>(m, "attribute", py::module_local());
324335
py::class_<IntegerAttr, Attribute>(m, "integer_attr", py::module_local());
@@ -386,6 +397,35 @@ void init_triton_ir(py::module &&m) {
386397
.def("get_after", &scf::WhileOp::getAfter, ret::reference);
387398
py::class_<scf::ConditionOp, OpState>(m, "ConditionOp", py::module_local());
388399

400+
py::class_<Operation, std::unique_ptr<Operation, py::nodelete>>(
401+
m, "operation", py::module_local())
402+
.def("get_name",
403+
[](Operation &self) {
404+
llvm::StringRef opName = self.getName().getStringRef();
405+
return opName.str();
406+
})
407+
.def("get_num_operands", &Operation::getNumOperands)
408+
.def("get_operand", &Operation::getOperand)
409+
.def("get_num_results", &Operation::getNumResults)
410+
.def("get_result", &Operation::getResult)
411+
.def("get_num_regions", &Operation::getNumRegions)
412+
.def("get_region", &Operation::getRegion, ret::reference)
413+
.def("get_block", &Operation::getBlock, ret::reference)
414+
.def("get_str_attr",
415+
[](Operation &self, const std::string &name) -> py::object {
416+
auto ret = self.getAttrOfType<StringAttr>(name);
417+
if (!ret)
418+
return py::none();
419+
return py::str(ret.getValue().str());
420+
})
421+
.def("get_flat_symbol_ref_attr",
422+
[](Operation &self, const std::string &name) -> py::object {
423+
auto ret = self.getAttrOfType<FlatSymbolRefAttr>(name);
424+
if (!ret)
425+
return py::none();
426+
return py::str(ret.getValue().str());
427+
});
428+
389429
// dynamic_attr is used to transfer ownership of the MLIR context to the
390430
// module
391431
py::class_<ModuleOp, OpState>(m, "module", py::module_local(),
@@ -414,12 +454,17 @@ void init_triton_ir(py::module &&m) {
414454
[](ModuleOp &self, std::string &funcName) -> FuncOp {
415455
return self.lookupSymbol<FuncOp>(funcName);
416456
})
417-
.def("get_int_attr", [](ModuleOp &self, std::string name) -> py::object {
418-
auto ret = self->getAttrOfType<IntegerAttr>(name);
419-
if (!ret)
420-
return py::none();
421-
return py::int_(ret.getInt());
422-
});
457+
.def("get_int_attr",
458+
[](ModuleOp &self, std::string name) -> py::object {
459+
auto ret = self->getAttrOfType<IntegerAttr>(name);
460+
if (!ret)
461+
return py::none();
462+
return py::int_(ret.getInt());
463+
})
464+
.def("walk",
465+
[](ModuleOp &self, const std::function<void(Operation *)> &fn) {
466+
self.walk(fn);
467+
});
423468

424469
m.def("make_attr", [](const std::vector<int> &values, MLIRContext &context) {
425470
return DenseIntElementsAttr::get(
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import triton
2+
import triton.language as tl
3+
4+
import torch
5+
6+
7+
@triton.jit
8+
def add_helper(x, y):
9+
return x + y
10+
11+
12+
@triton.jit
13+
def add_kernel(
14+
in_ptr0,
15+
in_ptr1,
16+
n_elements,
17+
out_ptr,
18+
BLOCK_SIZE: "tl.constexpr",
19+
):
20+
pid = tl.program_id(axis=0)
21+
block_start = pid * BLOCK_SIZE
22+
offsets = block_start + tl.arange(0, BLOCK_SIZE)
23+
mask = offsets < n_elements
24+
x = tl.load(in_ptr0 + offsets, mask=mask)
25+
y = tl.load(in_ptr1 + offsets, mask=mask)
26+
output = add_helper(x, y)
27+
tl.store(out_ptr + offsets, output, mask=mask)
28+
29+
30+
def test_module_walk():
31+
"""
32+
Test the MLIR bindings exposed for the out-of-tree walk.
33+
"""
34+
35+
def walk_fn(op):
36+
name = op.get_name()
37+
for i in range(op.get_num_results()):
38+
op.get_result(i).id()
39+
for i in range(op.get_num_operands()):
40+
op.get_operand(i).id()
41+
for i in range(op.get_num_regions()):
42+
op.get_region(i).id()
43+
block = op.get_block()
44+
if block is not None:
45+
block.id()
46+
for i in range(block.get_num_arguments()):
47+
block.get_argument(i)
48+
if name == "tt.func":
49+
op.get_str_attr("sym_name")
50+
if name == "tt.call":
51+
op.get_flat_symbol_ref_attr("callee")
52+
53+
kernel = add_kernel
54+
args = [
55+
torch.empty((32, 32), device="cuda"), # in_ptr0
56+
torch.empty((32, 32), device="cuda"), # in_ptr1
57+
1024, # n_elements
58+
torch.empty((32, 32), device="cuda"), # out_ptr
59+
16, # BLOCK_SIZE
60+
]
61+
src = triton.compiler.compiler.ASTSource(
62+
fn=kernel,
63+
signature={i: kernel._type_of(kernel._key_of(arg))
64+
for i, arg in enumerate(args)
65+
if i not in kernel.constexprs},
66+
constants={i: arg
67+
for i, arg in enumerate(args)
68+
if not isinstance(arg, torch.Tensor)},
69+
attrs=kernel._get_config(*args, ),
70+
)
71+
72+
context = triton._C.libtriton.ir.context()
73+
target = triton.runtime.driver.active.get_current_target()
74+
backend = triton.compiler.compiler.make_backend(target)
75+
options = backend.parse_options(dict())
76+
triton._C.libtriton.ir.load_dialects(context)
77+
backend.load_dialects(context)
78+
79+
ttir_module = src.make_ir(options, context)
80+
ttir_module.walk(walk_fn)

0 commit comments

Comments
 (0)