Skip to content

Commit 7862420

Browse files
ptillet authored and shunting314 committed
[FRONTEND] refactor compiler submodule (triton-lang#2701)
1 parent 9d8f96c commit 7862420

File tree

19 files changed

+526
-760
lines changed

19 files changed

+526
-760
lines changed

.github/workflows/integration-tests.yml

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,6 @@ jobs:
3434
echo '::set-output name=matrix-optional::["ubuntu-latest"]'
3535
fi
3636
37-
3837
Integration-Tests:
3938
needs: Runner-Preparation
4039

@@ -49,7 +48,7 @@ jobs:
4948
- name: Checkout
5049
uses: actions/checkout@v3
5150
with:
52-
submodules: 'true'
51+
submodules: "true"
5352
- name: Set CUDA ENV
5453
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
5554
run: |

python/setup.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -353,6 +353,7 @@ def build_extension(self, ext):
353353
"triton/_C",
354354
"triton/common",
355355
"triton/compiler",
356+
"triton/compiler/backends",
356357
"triton/language",
357358
"triton/language/extra",
358359
"triton/ops",

python/src/triton.cc

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1685,9 +1685,9 @@ void init_triton_ir(py::module &&m) {
16851685
self.addPass(mlir::triton::createReorderBroadcastPass());
16861686
})
16871687
.def("add_rewrite_tensor_pointer_pass",
1688-
[](mlir::PassManager &self, int computeCapability) {
1689-
self.addPass(mlir::triton::createRewriteTensorPointerPass(
1690-
computeCapability));
1688+
[](mlir::PassManager &self, int capability) {
1689+
self.addPass(
1690+
mlir::triton::createRewriteTensorPointerPass(capability));
16911691
})
16921692
.def("add_tritongpu_ws_feasibility_checking_pass",
16931693
[](mlir::PassManager &self, int computeCapability) {
@@ -1761,9 +1761,9 @@ void init_triton_ir(py::module &&m) {
17611761
self.addPass(mlir::createTritonGPUReorderInstructionsPass());
17621762
})
17631763
.def("add_tritongpu_rewrite_tensor_pointer_pass",
1764-
[](mlir::PassManager &self, int computeCapability) {
1765-
self.addPass(mlir::createTritonGPURewriteTensorPointerPass(
1766-
computeCapability));
1764+
[](mlir::PassManager &self, int capability) {
1765+
self.addPass(
1766+
mlir::createTritonGPURewriteTensorPointerPass(capability));
17671767
})
17681768
.def("add_tritongpu_decompose_conversions_pass",
17691769
[](mlir::PassManager &self) {

python/test/unit/hopper/test_persistent_warp_specialized_gemm.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -900,11 +900,12 @@ def process_epilogue(d, bias, w, epilogue):
900900
NUM_SMS = torch.cuda.get_device_properties('cuda').multi_processor_count
901901
if NUM_CTAS > 1:
902902
device = get_current_device()
903-
null_kernel = triton.compile(empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64})
903+
src = triton.compiler.ASTSource(fn=empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64})
904+
null_kernel = triton.compile(src)
904905
null_kernel._init_handles()
905906
max_shared_mem = driver.utils.get_device_properties(device)["max_shared_mem"]
906-
num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.cu_function, max_shared_mem, NUM_CTAS,
907-
1, 1)
907+
num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.function, max_shared_mem, NUM_CTAS, 1,
908+
1)
908909
NUM_SMS = num_clusters
909910

910911
def grid(META):
Lines changed: 13 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,12 @@
11
import multiprocessing
22
import os
33
import shutil
4-
from collections import namedtuple
54

65
import torch
76

87
import triton
98
import triton.language as tl
9+
from triton.compiler import ASTSource
1010

1111
tmpdir = ".tmp"
1212

@@ -17,32 +17,26 @@ def reset_tmp_dir():
1717
shutil.rmtree(tmpdir, ignore_errors=True)
1818

1919

20-
instance_descriptor = namedtuple("instance_descriptor",
21-
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])
22-
23-
24-
def compile_fn(config, cc):
20+
def compile_fn(attrs, capability):
2521

2622
@triton.jit
2723
def kernel_sub(a, b, o, N: tl.constexpr):
2824
idx = tl.arange(0, N)
2925
tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)
3026

31-
triton.compile(
27+
src = ASTSource(
3228
fn=kernel_sub,
33-
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
34-
device=0,
3529
constants={3: 32},
36-
configs=[config],
37-
warm_cache_only=True,
38-
cc=cc,
30+
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
31+
attrs=attrs,
3932
)
33+
triton.compile(src=src, target=("cuda", capability))
4034

4135

4236
def test_compile_in_subproc() -> None:
4337
major, minor = torch.cuda.get_device_capability(0)
4438
cc = major * 10 + minor
45-
config = instance_descriptor(tuple(range(4)), (), (), ())
39+
config = triton.compiler.AttrsDescriptor(tuple(range(4)), (), (), ())
4640

4741
multiprocessing.set_start_method('fork')
4842
proc = multiprocessing.Process(target=compile_fn, args=(config, cc))
@@ -51,7 +45,7 @@ def test_compile_in_subproc() -> None:
5145
assert proc.exitcode == 0
5246

5347

54-
def compile_fn_dot(config, cc):
48+
def compile_fn_dot(attrs, capability):
5549

5650
@triton.jit
5751
def kernel_dot(Z):
@@ -60,24 +54,18 @@ def kernel_dot(Z):
6054
z = tl.dot(z, z)
6155
tl.store(Z + offs, z)
6256

63-
triton.compile(
64-
fn=kernel_dot,
65-
signature={0: "*fp32"},
66-
device=0,
67-
configs=[config],
68-
warm_cache_only=True,
69-
cc=cc,
70-
)
57+
src = ASTSource(fn=kernel_dot, signature={0: "*fp32"}, attrs=attrs, constants=dict())
58+
triton.compile(src=src, target=("cuda", capability))
7159

7260

7361
def test_compile_in_forked_subproc() -> None:
7462
reset_tmp_dir()
7563
major, minor = torch.cuda.get_device_capability(0)
76-
cc = major * 10 + minor
77-
config = instance_descriptor(tuple(range(1)), (), (), ())
64+
capability = major * 10 + minor
65+
config = triton.compiler.AttrsDescriptor(tuple(range(1)), (), (), ())
7866

7967
assert multiprocessing.get_start_method() == 'fork'
80-
proc = multiprocessing.Process(target=compile_fn_dot, args=(config, cc))
68+
proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability))
8169
proc.start()
8270
proc.join()
8371
assert proc.exitcode == 0

python/test/unit/tools/test_aot.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -446,7 +446,7 @@ def test_ttgir_to_ptx():
446446
kernel_path = os.path.join(tmp_dir, "empty_kernel.ttgir")
447447
with open(kernel_path, "w") as fp:
448448
fp.write(src)
449-
k = triton.compile(kernel_path, cc=80)
449+
k = triton.compile(kernel_path, target=("cuda", 80))
450450
ptx = k.asm["ptx"]
451451
assert ".target sm_80" in ptx
452452
assert ".address_size 64" in ptx

python/triton/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121

2222
from . import language
2323
from . import testing
24+
from . import tools
2425

2526
__all__ = [
2627
"autotune",

python/triton/compiler/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,7 @@
1-
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps,
2-
instance_descriptor)
1+
from .compiler import (CompiledKernel, ASTSource, compile, AttrsDescriptor)
32
from .errors import CompilationError
43

54
__all__ = [
6-
"compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps",
5+
"compile", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps",
76
"get_arch_default_num_stages"
87
]
File renamed without changes.

0 commit comments

Comments
 (0)