diff --git a/python/src/ir.cc b/python/src/ir.cc
index 8ed9b0e66acc..e87aef8f8745 100644
--- a/python/src/ir.cc
+++ b/python/src/ir.cc
@@ -20,6 +20,7 @@
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Tools/Sys/GetEnv.hpp"
+#include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
@@ -249,14 +250,22 @@ void init_triton_ir(py::module &&m) {
            [](Value &self, Value &newValue) {
              self.replaceAllUsesWith(newValue);
            })
-      .def("get_type", &Value::getType);
+      .def("get_type", &Value::getType)
+      .def("id", [](Value &self) {
+        // The Value is identified by and compared with
+        // other Values via the underlying ValueImpl
+        return (uint64_t)self.getImpl();
+      });
+
+  py::class_<OpResult, Value>(m, "op_result", py::module_local());
 
   py::class_<BlockArgument, Value>(m, "block_argument", py::module_local());
 
   py::class_<Region>(m, "region", py::module_local())
       .def("get_parent_region", &Region::getParentRegion, ret::reference)
       .def("size", [](Region &self) { return self.getBlocks().size(); })
-      .def("empty", &Region::empty);
+      .def("empty", &Region::empty)
+      .def("id", [](Region &self) { return (uint64_t)&self; });
 
   py::class_<Block>(m, "block", py::module_local())
       .def("arg",
@@ -271,6 +280,7 @@ void init_triton_ir(py::module &&m) {
             self.addArgument(ty, loc);
           })
       .def("get_num_arguments", &Block::getNumArguments)
+      .def("get_argument", &Block::getArgument)
       .def("dump", &Block::dump)
       .def("move_before",
            [](Block &self, Block &dst) { self.moveBefore(&dst); })
@@ -318,7 +328,8 @@ void init_triton_ir(py::module &&m) {
             return !self.empty() &&
                    self.back().hasTrait<OpTrait::ReturnLike>();
           })
-      .def("erase", [](Block &self) { self.erase(); });
+      .def("erase", [](Block &self) { self.erase(); })
+      .def("id", [](Block &self) { return (uint64_t)&self; });
 
   py::class_<Attribute>(m, "attribute", py::module_local());
   py::class_<IntegerAttr, Attribute>(m, "integer_attr", py::module_local());
@@ -386,6 +397,35 @@ void init_triton_ir(py::module &&m) {
       .def("get_after", &scf::WhileOp::getAfter, ret::reference);
   py::class_<scf::ConditionOp, OpState>(m, "ConditionOp", py::module_local());
 
+  py::class_<Operation, std::unique_ptr<Operation, py::nodelete>>(
+      m, "operation", py::module_local())
+      .def("get_name",
+           [](Operation &self) {
+             llvm::StringRef opName = self.getName().getStringRef();
+             return opName.str();
+           })
+      .def("get_num_operands", &Operation::getNumOperands)
+      .def("get_operand", &Operation::getOperand)
+      .def("get_num_results", &Operation::getNumResults)
+      .def("get_result", &Operation::getResult)
+      .def("get_num_regions", &Operation::getNumRegions)
+      .def("get_region", &Operation::getRegion, ret::reference)
+      .def("get_block", &Operation::getBlock, ret::reference)
+      .def("get_str_attr",
+           [](Operation &self, const std::string &name) -> py::object {
+             auto ret = self.getAttrOfType<StringAttr>(name);
+             if (!ret)
+               return py::none();
+             return py::str(ret.getValue().str());
+           })
+      .def("get_flat_symbol_ref_attr",
+           [](Operation &self, const std::string &name) -> py::object {
+             auto ret = self.getAttrOfType<FlatSymbolRefAttr>(name);
+             if (!ret)
+               return py::none();
+             return py::str(ret.getValue().str());
+           });
+
   // dynamic_attr is used to transfer ownership of the MLIR context to the
   // module
   py::class_<ModuleOp, OpState>(m, "module", py::module_local(),
@@ -414,12 +454,17 @@ void init_triton_ir(py::module &&m) {
            [](ModuleOp &self, std::string &funcName) -> FuncOp {
              return self.lookupSymbol<FuncOp>(funcName);
            })
-      .def("get_int_attr", [](ModuleOp &self, std::string name) -> py::object {
-        auto ret = self->getAttrOfType<IntegerAttr>(name);
-        if (!ret)
-          return py::none();
-        return py::int_(ret.getInt());
-      });
+      .def("get_int_attr",
+           [](ModuleOp &self, std::string name) -> py::object {
+             auto ret = self->getAttrOfType<IntegerAttr>(name);
+             if (!ret)
+               return py::none();
+             return py::int_(ret.getInt());
+           })
+      .def("walk",
+           [](ModuleOp &self, const std::function<void(Operation *)> &fn) {
+             self.walk(fn);
+           });
 
   m.def("make_attr", [](const std::vector<int> &values, MLIRContext &context) {
     return DenseIntElementsAttr::get(
diff --git a/python/test/unit/runtime/test_bindings.py b/python/test/unit/runtime/test_bindings.py
new file mode 100644
index 000000000000..23c867a597ec
--- /dev/null
+++ b/python/test/unit/runtime/test_bindings.py
@@ -0,0 +1,80 @@
+import triton
+import triton.language as tl
+
+import torch
+
+
+@triton.jit
+def add_helper(x, y):
+    return x + y
+
+
+@triton.jit
+def add_kernel(
+    in_ptr0,
+    in_ptr1,
+    n_elements,
+    out_ptr,
+    BLOCK_SIZE: "tl.constexpr",
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    x = tl.load(in_ptr0 + offsets, mask=mask)
+    y = tl.load(in_ptr1 + offsets, mask=mask)
+    output = add_helper(x, y)
+    tl.store(out_ptr + offsets, output, mask=mask)
+
+
+def test_module_walk():
+    """
+    Test the MLIR bindings exposed for the out-of-tree walk.
+    """
+
+    def walk_fn(op):
+        name = op.get_name()
+        for i in range(op.get_num_results()):
+            op.get_result(i).id()
+        for i in range(op.get_num_operands()):
+            op.get_operand(i).id()
+        for i in range(op.get_num_regions()):
+            op.get_region(i).id()
+        block = op.get_block()
+        if block is not None:
+            block.id()
+            for i in range(block.get_num_arguments()):
+                block.get_argument(i)
+        if name == "tt.func":
+            op.get_str_attr("sym_name")
+        if name == "tt.call":
+            op.get_flat_symbol_ref_attr("callee")
+
+    kernel = add_kernel
+    args = [
+        torch.empty((32, 32), device="cuda"),  # in_ptr0
+        torch.empty((32, 32), device="cuda"),  # in_ptr1
+        1024,  # n_elements
+        torch.empty((32, 32), device="cuda"),  # out_ptr
+        16,  # BLOCK_SIZE
+    ]
+    src = triton.compiler.compiler.ASTSource(
+        fn=kernel,
+        signature={i: kernel._type_of(kernel._key_of(arg))
+                   for i, arg in enumerate(args)
+                   if i not in kernel.constexprs},
+        constants={i: arg
+                   for i, arg in enumerate(args)
+                   if not isinstance(arg, torch.Tensor)},
+        attrs=kernel._get_config(*args, ),
+    )
+
+    context = triton._C.libtriton.ir.context()
+    target = triton.runtime.driver.active.get_current_target()
+    backend = triton.compiler.compiler.make_backend(target)
+    options = backend.parse_options(dict())
+    triton._C.libtriton.ir.load_dialects(context)
+    backend.load_dialects(context)
+
+    ttir_module = src.make_ir(options, context)
+    ttir_module.walk(walk_fn)