.github/workflows/integration-tests.yml (3 changes: 1 addition & 2 deletions)
@@ -34,7 +34,6 @@ jobs:
echo '::set-output name=matrix-optional::["ubuntu-latest"]'
fi

-
Integration-Tests:
needs: Runner-Preparation

@@ -49,7 +48,7 @@ jobs:
- name: Checkout
uses: actions/checkout@v3
with:
-submodules: 'true'
+submodules: "true"
- name: Set CUDA ENV
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
run: |
python/setup.py (3 changes: 2 additions & 1 deletion)
@@ -343,7 +343,7 @@ def build_extension(self, ext):

setup(
name=os.environ.get("TRITON_WHEEL_NAME", "triton"),
version="2.2.0" + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", ""),
version="2.3.0" + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", ""),
author="Philippe Tillet",
author_email="[email protected]",
description="A language and compiler for custom Deep Learning operations",
@@ -353,6 +353,7 @@ def build_extension(self, ext):
"triton/_C",
"triton/common",
"triton/compiler",
"triton/compiler/backends",
"triton/language",
"triton/language/extra",
"triton/ops",
python/src/triton.cc (87 changes: 71 additions & 16 deletions)
@@ -66,6 +66,7 @@

#include <pybind11/numpy.h>
namespace py = pybind11;
+using namespace mlir;

PYBIND11_MAKE_OPAQUE(mlir::triton::gpu::TMAMetadataTy);

@@ -170,7 +171,7 @@ class TritonOpBuilder {
private:
std::unique_ptr<mlir::OpBuilder> builder;
std::unique_ptr<mlir::Location> lastLoc;
-bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO");
+bool lineInfoEnabled = !::triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO");
};

static std::string locationToString(mlir::Location loc) {
@@ -347,15 +348,23 @@ void init_triton_ir(py::module &&m) {
[](mlir::Value &self, mlir::Value &newValue) {
self.replaceAllUsesWith(newValue);
})
.def("get_type", &mlir::Value::getType);
.def("get_type", &mlir::Value::getType)
.def("id", [](Value &self) {
// The Value is identified by and compared with
// other Values via the underlying ValueImpl
return (uint64_t)self.getImpl();
});

+py::class_<OpResult, Value>(m, "op_result", py::module_local());

py::class_<mlir::BlockArgument, mlir::Value>(m, "block_argument",
py::module_local());

py::class_<mlir::Region>(m, "region", py::module_local())
.def("get_parent_region", &mlir::Region::getParentRegion, ret::reference)
.def("size", [](mlir::Region &self) { return self.getBlocks().size(); })
.def("empty", &mlir::Region::empty);
.def("empty", &mlir::Region::empty)
.def("id", [](Region &self) { return (uint64_t)&self; });

py::class_<mlir::Block>(m, "block", py::module_local())
.def("arg",
@@ -368,6 +377,7 @@ void init_triton_ir(py::module &&m) {
self.addArgument(ty, loc);
})
.def("get_num_arguments", &mlir::Block::getNumArguments)
.def("get_argument", &Block::getArgument)
.def("dump", &mlir::Block::dump)
.def("move_before", &mlir::Block::moveBefore)
.def("insert_before", &mlir::Block::insertBefore)
@@ -414,7 +424,8 @@ void init_triton_ir(py::module &&m) {
return !self.empty() &&
self.back().hasTrait<mlir::OpTrait::ReturnLike>();
})
.def("erase", [](mlir::Block &self) { self.erase(); });
.def("erase", [](mlir::Block &self) { self.erase(); })
.def("id", [](Block &self) { return (uint64_t)&self; });

// using eattr = ir::attribute_kind_t;
// py::enum_<eattr>(m, "attribute_kind")
@@ -461,7 +472,9 @@ void init_triton_ir(py::module &&m) {
[](mlir::OpState &self) -> std::string {
std::string str;
llvm::raw_string_ostream os(str);
-self->print(os);
+auto printingFlags = mlir::OpPrintingFlags();
+printingFlags.enableDebugInfo();
+self->print(os, printingFlags);
return str;
})
.def("append_operand",
@@ -489,6 +502,35 @@ void init_triton_ir(py::module &&m) {
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "ConditionOp",
py::module_local());

+py::class_<Operation, std::unique_ptr<Operation, py::nodelete>>(
+m, "operation", py::module_local())
+.def("get_name",
+[](Operation &self) {
+llvm::StringRef opName = self.getName().getStringRef();
+return opName.str();
+})
+.def("get_num_operands", &Operation::getNumOperands)
+.def("get_operand", &Operation::getOperand)
+.def("get_num_results", &Operation::getNumResults)
+.def("get_result", &Operation::getResult)
+.def("get_num_regions", &Operation::getNumRegions)
+.def("get_region", &Operation::getRegion, ret::reference)
+.def("get_block", &Operation::getBlock, ret::reference)
+.def("get_str_attr",
+[](Operation &self, const std::string &name) -> py::object {
+auto ret = self.getAttrOfType<StringAttr>(name);
+if (!ret)
+return py::none();
+return py::str(ret.getValue().str());
+})
+.def("get_flat_symbol_ref_attr",
+[](Operation &self, const std::string &name) -> py::object {
+auto ret = self.getAttrOfType<FlatSymbolRefAttr>(name);
+if (!ret)
+return py::none();
+return py::str(ret.getValue().str());
+});

// dynamic_attr is used to transfer ownership of the MLIR context to the
// module
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module", py::module_local(),
@@ -498,7 +540,9 @@ void init_triton_ir(py::module &&m) {
[](mlir::ModuleOp &self) -> std::string {
std::string str;
llvm::raw_string_ostream os(str);
-self.print(os);
+auto printingFlags = mlir::OpPrintingFlags();
+printingFlags.enableDebugInfo();
+self.print(os, printingFlags);
return str;
})
.def("bytecode",
@@ -532,6 +576,17 @@ void init_triton_ir(py::module &&m) {
if (funcs.size() != 1)
throw std::runtime_error("Expected a single function");
return funcs[0];
})
.def("get_int_attr",
[](ModuleOp &self, std::string name) -> py::object {
auto ret = self->getAttrOfType<IntegerAttr>(name);
if (!ret)
return py::none();
return py::int_(ret.getInt());
})
.def("walk",
[](ModuleOp &self, const std::function<void(Operation *)> &fn) {
self.walk(fn);
});

m.def("make_attr",
@@ -1685,9 +1740,9 @@ void init_triton_ir(py::module &&m) {
self.addPass(mlir::triton::createReorderBroadcastPass());
})
.def("add_rewrite_tensor_pointer_pass",
-[](mlir::PassManager &self, int computeCapability) {
-self.addPass(mlir::triton::createRewriteTensorPointerPass(
-computeCapability));
+[](mlir::PassManager &self, int capability) {
+self.addPass(
+mlir::triton::createRewriteTensorPointerPass(capability));
})
.def("add_tritongpu_ws_feasibility_checking_pass",
[](mlir::PassManager &self, int computeCapability) {
@@ -1761,9 +1816,9 @@ void init_triton_ir(py::module &&m) {
self.addPass(mlir::createTritonGPUReorderInstructionsPass());
})
.def("add_tritongpu_rewrite_tensor_pointer_pass",
-[](mlir::PassManager &self, int computeCapability) {
-self.addPass(mlir::createTritonGPURewriteTensorPointerPass(
-computeCapability));
+[](mlir::PassManager &self, int capability) {
+self.addPass(
+mlir::createTritonGPURewriteTensorPointerPass(capability));
})
.def("add_tritongpu_decompose_conversions_pass",
[](mlir::PassManager &self) {
@@ -1794,8 +1849,8 @@ void init_triton_ir(py::module &&m) {
void init_triton_env_vars(py::module &m) {
m.def("get_env_vars", []() -> std::map<std::string, bool> {
std::map<std::string, bool> envVars;
-for (const auto &envVar : triton::ENV_VARS) {
-envVars[envVar] = triton::tools::getBoolEnv(envVar);
+for (const auto &envVar : ::triton::ENV_VARS) {
+envVars[envVar] = ::triton::tools::getBoolEnv(envVar);
}
return envVars;
});
@@ -1896,7 +1951,7 @@ void init_triton_translation(py::module &m) {
"lineno: " + std::to_string(error.getLineNo()));
}
// translate module to PTX
-auto ptxCode = triton::translateLLVMIRToPTX(*module, capability,
+auto ptxCode = ::triton::translateLLVMIRToPTX(*module, capability,
version, enable_fp_fusion);
return ptxCode;
},
@@ -1925,7 +1980,7 @@ void init_triton_translation(py::module &m) {
ofs.close();

auto lineInfoOption =
-triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO")
+::triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO")
? ""
: " -lineinfo";
auto fmadOption = enable_fp_fusion ? "" : " --fmad=false";
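Note on the bindings above: the operation/region/block/value id() methods, the attribute getters, and module.walk together form the Python-facing surface for out-of-tree IR inspection. A minimal sketch of how they compose, assuming a ttir_module built as in the new test_bindings.py at the bottom of this PR (the callback name dump_op is illustrative):

def dump_op(op):
    # Every visited op exposes its name, results, operands, and regions.
    name = op.get_name()  # e.g. "tt.func" or "tt.call"
    results = [op.get_result(i).id() for i in range(op.get_num_results())]
    operands = [op.get_operand(i).id() for i in range(op.get_num_operands())]
    print(name, results, operands)
    if name == "tt.call":
        # Attribute getters return None when the op carries no such attribute.
        print("  callee:", op.get_flat_symbol_ref_attr("callee"))

ttir_module.walk(dump_op)  # visits every operation nested in the module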
(unnamed Python file: header not captured)
@@ -27,7 +27,6 @@
import triton
import triton.language as tl
from triton.runtime import driver
-from triton.runtime.jit import get_current_device


# kernel used to query max clusters for persistent kernel when NUM_CTAS > 1
@@ -899,12 +898,13 @@ def process_epilogue(d, bias, w, epilogue):

NUM_SMS = torch.cuda.get_device_properties('cuda').multi_processor_count
if NUM_CTAS > 1:
-device = get_current_device()
-null_kernel = triton.compile(empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64})
+src = triton.compiler.ASTSource(fn=empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64})
+null_kernel = triton.compile(src)
+null_kernel._init_handles()
+device = driver.get_current_device()
max_shared_mem = driver.utils.get_device_properties(device)["max_shared_mem"]
-num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.cu_function, max_shared_mem, NUM_CTAS,
-1, 1)
+num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.function, max_shared_mem, NUM_CTAS, 1,
+1)
NUM_SMS = num_clusters

def grid(META):
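The hunk above tracks the new compile entry point: signature and constants now travel in a triton.compiler.ASTSource object instead of as keyword arguments to triton.compile, and the raw handle moves from cu_function to function. A minimal sketch under those assumptions (noop_kernel is an illustrative stand-in, not the tutorial's empty_kernel):

import triton
import triton.language as tl


@triton.jit
def noop_kernel(x, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pass  # placeholder body; only the compiled handle is of interest


src = triton.compiler.ASTSource(fn=noop_kernel, signature="i32",
                                constants={"BLOCK_M": 64, "BLOCK_N": 64})
kernel = triton.compile(src)
kernel._init_handles()  # populates kernel.function for driver-level APIs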
python/test/unit/runtime/test_bindings.py (82 changes: 82 additions & 0 deletions)
@@ -0,0 +1,82 @@
import triton
import triton.language as tl
from triton.compiler.backends.cuda import CUDABackend
from triton.runtime.driver import driver

import torch


@triton.jit
def add_helper(x, y):
return x + y


@triton.jit
def add_kernel(
in_ptr0,
in_ptr1,
n_elements,
out_ptr,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = add_helper(x, y)
tl.store(out_ptr + offsets, output, mask=mask)


def test_module_walk():
"""
Test the MLIR bindings exposed for the out-of-tree walk.
"""

def walk_fn(op):
name = op.get_name()
for i in range(op.get_num_results()):
op.get_result(i).id()
for i in range(op.get_num_operands()):
op.get_operand(i).id()
for i in range(op.get_num_regions()):
op.get_region(i).id()
block = op.get_block()
if block is not None:
block.id()
for i in range(block.get_num_arguments()):
block.get_argument(i)
if name == "tt.func":
op.get_str_attr("sym_name")
if name == "tt.call":
op.get_flat_symbol_ref_attr("callee")

kernel = add_kernel
args = [
torch.empty((32, 32), device="cuda"), # in_ptr0
torch.empty((32, 32), device="cuda"), # in_ptr1
1024, # n_elements
torch.empty((32, 32), device="cuda"), # out_ptr
16, # BLOCK_SIZE
]
src = triton.compiler.compiler.ASTSource(
fn=kernel,
signature={i: kernel._type_of(kernel._key_of(arg))
for i, arg in enumerate(args)
if i not in kernel.constexprs},
constants={i: arg
for i, arg in enumerate(args)
if not isinstance(arg, torch.Tensor)},
attrs=kernel._get_config(*args),
)

triton._C.libtriton.ir = triton._C.libtriton.triton.ir
context = triton._C.libtriton.ir.context()

target = driver.get_current_target()
backend = CUDABackend(target)
options = backend.parse_options(dict())

ttir_module = src.make_ir(options)
ttir_module.walk(walk_fn)
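The walk test exercises nearly every new binding; the module-level get_int_attr is the one it leaves untouched. For completeness, a hedged one-liner against the same ttir_module (the attribute name is assumed from the TritonGPU lowering, so fresh TTIR should not carry it):

num_warps = ttir_module.get_int_attr("triton_gpu.num-warps")  # assumed attr name
assert num_warps is None  # not set until the GPU conversion passes run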