11import multiprocessing
22import os
33import shutil
4- from collections import namedtuple
54
65import torch
76
87import triton
98import triton .language as tl
9+ from triton .compiler import ASTSource
1010
1111tmpdir = ".tmp"
1212
@@ -17,32 +17,26 @@ def reset_tmp_dir():
1717 shutil .rmtree (tmpdir , ignore_errors = True )
1818
1919
20- instance_descriptor = namedtuple ("instance_descriptor" ,
21- ["divisible_by_16" , "equal_to_1" , "ids_of_folded_args" , "divisible_by_8" ])
22-
23-
24- def compile_fn (config , cc ):
20+ def compile_fn (attrs , capability ):
2521
2622 @triton .jit
2723 def kernel_sub (a , b , o , N : tl .constexpr ):
2824 idx = tl .arange (0 , N )
2925 tl .store (o + idx , tl .load (a + idx ) - tl .load (b + idx ) * 777 )
3026
31- triton . compile (
27+ src = ASTSource (
3228 fn = kernel_sub ,
33- signature = {0 : "*fp32" , 1 : "*fp32" , 2 : "*fp32" },
34- device = 0 ,
3529 constants = {3 : 32 },
36- configs = [config ],
37- warm_cache_only = True ,
38- cc = cc ,
30+ signature = {0 : "*fp32" , 1 : "*fp32" , 2 : "*fp32" },
31+ attrs = attrs ,
3932 )
33+ triton .compile (src = src , target = ("cuda" , capability ))
4034
4135
4236def test_compile_in_subproc () -> None :
4337 major , minor = torch .cuda .get_device_capability (0 )
4438 cc = major * 10 + minor
45- config = instance_descriptor (tuple (range (4 )), (), (), ())
39+ config = triton . compiler . AttrsDescriptor (tuple (range (4 )), (), (), ())
4640
4741 multiprocessing .set_start_method ('fork' )
4842 proc = multiprocessing .Process (target = compile_fn , args = (config , cc ))
@@ -51,7 +45,7 @@ def test_compile_in_subproc() -> None:
5145 assert proc .exitcode == 0
5246
5347
54- def compile_fn_dot (config , cc ):
48+ def compile_fn_dot (attrs , capability ):
5549
5650 @triton .jit
5751 def kernel_dot (Z ):
@@ -60,24 +54,18 @@ def kernel_dot(Z):
6054 z = tl .dot (z , z )
6155 tl .store (Z + offs , z )
6256
63- triton .compile (
64- fn = kernel_dot ,
65- signature = {0 : "*fp32" },
66- device = 0 ,
67- configs = [config ],
68- warm_cache_only = True ,
69- cc = cc ,
70- )
57+ src = ASTSource (fn = kernel_dot , signature = {0 : "*fp32" }, attrs = attrs , constants = dict ())
58+ triton .compile (src = src , target = ("cuda" , capability ))
7159
7260
7361def test_compile_in_forked_subproc () -> None :
7462 reset_tmp_dir ()
7563 major , minor = torch .cuda .get_device_capability (0 )
76- cc = major * 10 + minor
77- config = instance_descriptor (tuple (range (1 )), (), (), ())
64+ capability = major * 10 + minor
65+ config = triton . compiler . AttrsDescriptor (tuple (range (1 )), (), (), ())
7866
7967 assert multiprocessing .get_start_method () == 'fork'
80- proc = multiprocessing .Process (target = compile_fn_dot , args = (config , cc ))
68+ proc = multiprocessing .Process (target = compile_fn_dot , args = (config , capability ))
8169 proc .start ()
8270 proc .join ()
8371 assert proc .exitcode == 0
0 commit comments