Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import triton
import triton.language as tl
from triton.runtime import driver
from triton.runtime.jit import get_current_device


# kernel used to query max clusters for persistent kernel when NUM_CTAS > 1
Expand Down Expand Up @@ -899,10 +898,10 @@ def process_epilogue(d, bias, w, epilogue):

num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
if NUM_CTAS > 1:
device = get_current_device()
src = triton.compiler.ASTSource(fn=empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64})
null_kernel = triton.compile(src)
null_kernel._init_handles()
device = driver.get_current_device()
max_shared_mem = driver.utils.get_device_properties(device)["max_shared_mem"]
num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.function, max_shared_mem, NUM_CTAS, 1,
1)
Expand Down
49 changes: 22 additions & 27 deletions python/triton/compiler/backends/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def ptx_get_version(cuda_version) -> int:
raise RuntimeError("Triton only support CUDA 10.0 or higher")


@dataclass
@dataclass(frozen=True)
class CUDAOptions:
num_warps: int = 4
num_ctas: int = 1
Expand All @@ -52,9 +52,14 @@ class CUDAOptions:
enable_fp_fusion: bool = True
allow_fp8e4nv: bool = False
max_num_imprecise_acc_default: bool = None
extern_libs: dict = None
debug: bool = False

def __post_init__(self):
    # TODO: change API
    # Normalize `extern_libs`: accept a dict of {name: path} and convert it
    # into a tuple of (name, path) pairs, dropping entries whose path is
    # falsy.  `object.__setattr__` is required because the dataclass is
    # declared frozen; the tuple form is presumably chosen so the options
    # object stays immutable/hashable for cache keying — TODO confirm.
    if isinstance(self.extern_libs, dict):
        extern_libs = tuple([(k, v) for k, v in self.extern_libs.items() if v])
        object.__setattr__(self, 'extern_libs', extern_libs)
    # num_warps must be a positive power of two (single-bit check).
    assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
        "num_warps must be a power of 2"

Expand All @@ -63,30 +68,18 @@ def hash(self):
return hashlib.md5(key.encode("utf-8")).hexdigest()


@dataclass
class CUDALinkerOptions:
libs: dict = None

def __post_init__(self):
if self.libs is not None:
self.libs = {k: v for k, v in self.libs.items() if v}


class CUDABackend(BaseBackend):

def __init__(self, device_type: tuple) -> None:
super().__init__(device_type)
self.capability = device_type[1]
assert isinstance(self.capability, int)

def parse_compiler_options(self, opts) -> Any:
options = CUDAOptions(**opts)
options.allow_fp8e4nv = self.capability >= 89
options.max_num_imprecise_acc_default = 0 if self.capability >= 89 else None
return options

def parse_linker_options(self, opts) -> Any:
return CUDALinkerOptions(**opts)
def parse_options(self, opts) -> Any:
    """Build a CUDAOptions instance from a raw options mapping.

    Only keys that correspond to CUDAOptions dataclass fields are taken
    from ``opts``; unknown keys are silently ignored.  Two fields are then
    overridden based on the device compute capability rather than caller
    input: ``allow_fp8e4nv`` is enabled only on sm_89 and newer, and
    ``max_num_imprecise_acc_default`` is forced to 0 on sm_89+ (None
    otherwise).  Note this means caller-supplied values for those two keys
    are always discarded.
    """
    args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts}
    args["allow_fp8e4nv"] = self.capability >= 89
    args["max_num_imprecise_acc_default"] = 0 if self.capability >= 89 else None
    return CUDAOptions(**args)

@staticmethod
def make_ttir(mod, metadata, opt):
Expand Down Expand Up @@ -167,13 +160,15 @@ def make_ttgir(mod, metadata, opt, capability):
return mod

@staticmethod
def make_llir(src, metadata, linker_options, capability):
def make_llir(src, metadata, options, capability):
metadata["enable_warp_specialization"] = ir.is_ws_supported(src)
metadata["num_warps"] = get_num_warps(src)
tma_infos = TMAInfos()
# link libraries
if linker_options.libs:
add_external_libs(src, list(linker_options.libs.keys()), list(linker_options.libs.values()))
if options.extern_libs:
names = [lib[0] for lib in options.extern_libs]
paths = [lib[1] for lib in options.extern_libs]
add_external_libs(src, names, paths)
# TritonGPU -> LLVM-IR
ret = translate_triton_gpu_to_llvmir(src, capability, tma_infos, runtime.TARGET.NVVM)
if len(tma_infos) > 0:
Expand All @@ -198,12 +193,12 @@ def make_cubin(src, metadata, opt, capability):
ptxas, _ = path_to_ptxas()
return compile_ptx_to_cubin(src, ptxas, capability, opt.enable_fp_fusion)

def add_stages(self, stages, compiler_options, linker_options):
stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, compiler_options)
stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, compiler_options, self.capability)
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, linker_options, self.capability)
stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, compiler_options, self.capability)
stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, compiler_options, self.capability)
def add_stages(self, stages, options):
stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.capability)
stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability)

def hash(self):
return f'{get_cuda_version_key()}-{self.capability}'
Expand Down
11 changes: 5 additions & 6 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ def visit_Call(self, node: ast.Call) -> bool:

class CodeGenerator(ast.NodeVisitor):

def __init__(self, context, prototype, gscope, attributes, constants, function_name, options, module=None,
is_kernel=False, function_types: Optional[Dict] = None, noinline=False,
def __init__(self, context, prototype, gscope, attributes, constants, function_name, options, debug=None,
module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False,
file_name: Optional[str] = None, begin_line=0):
self.context = context
self.builder = ir.builder(context)
Expand All @@ -228,7 +228,7 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n
self.function_name = function_name
self.is_kernel = is_kernel
self.last_node = None
self.debug = options.debug
self.debug = options.debug if debug is None else debug
self.noinline = noinline
self.scf_stack = []
self.last_ret_type = None
Expand Down Expand Up @@ -981,12 +981,11 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
gscope = sys.modules[fn.fn.__module__].__dict__
# If the callee is not set, we use the same debug setting as the caller
file_name, begin_line = _get_fn_file_line(fn)
options = self.builder.options
options.debug = self.debug if fn.debug is None else fn.debug
debug = self.debug if fn.debug is None else fn.debug
generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
function_name=fn_name, function_types=self.function_ret_types,
noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
options=self.builder.options)
options=self.builder.options, debug=debug)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type
Expand Down
43 changes: 21 additions & 22 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
# TODO: runtime.errors
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager
from ..runtime.jit import get_current_device, get_cuda_stream, get_current_target
from ..runtime.driver import driver
from .utils import InfoFromBackendForTensorMap
from .backends.cuda import CUDABackend
Expand Down Expand Up @@ -121,8 +120,8 @@ def metadata(self):
# TODO: remove once TMA support is cleaned up
return {"ids_of_folded_args": tuple([int(k) for k in self.attrs.ids_of_folded_args])}

def update_options(self, options):
pass
def parse_options(self):
    # An AST source contributes no source-derived compiler options
    # (contrast with IRSource, which extracts num_warps from ttgir text).
    return dict()


class IRSource:
Expand Down Expand Up @@ -150,24 +149,24 @@ def make_ir(self, options):
def metadata(self):
return dict()

def update_options(self, options):
def parse_options(self):
if self.ext == "ttgir":
options.num_warps = _get_num_warps_from_ir_str(self.src)
return {'num_warps': _get_num_warps_from_ir_str(self.src)}
return dict()


def compile(src, target=None, compiler_options=None, linker_options=None):
def compile(src, target=None, options=None):
if target is None:
target = get_current_target()
target = driver.get_current_target()
backend = CUDABackend(target)
# create backend
compiler_options = backend.parse_compiler_options(compiler_options or dict())
linker_options = backend.parse_linker_options(linker_options or dict())
if not isinstance(src, ASTSource):
assert isinstance(src, str), "source must be either AST or a filepath"
src = IRSource(src)
src.update_options(compiler_options)
extra_options = src.parse_options()
options = backend.parse_options(dict(options or dict(), **extra_options))
# create cache manager
key = f"{src.hash()}-{backend.hash()}-{compiler_options.hash()}-{frozenset(sorted(get_env_vars().items()))}"
key = f"{src.hash()}-{backend.hash()}-{options.hash()}-{frozenset(sorted(get_env_vars().items()))}"
hash = hashlib.md5(key.encode("utf-8")).hexdigest()
fn_cache_manager = get_cache_manager(hash)
metadata_filename = f"{src.name}.json"
Expand All @@ -181,16 +180,15 @@ def compile(src, target=None, compiler_options=None, linker_options=None):
# initialize metadata
metadata = {
"target": target,
**compiler_options.__dict__,
**linker_options.__dict__,
**options.__dict__,
**get_env_vars(),
**src.metadata(),
}
# run compilation pipeline and populate metadata
stages = dict()
backend.add_stages(stages, compiler_options, linker_options)
backend.add_stages(stages, options)
first_stage = list(stages.keys()).index(src.ext)
module = src.make_ir(compiler_options)
module = src.make_ir(options)
for ext, compile_ir in list(stages.items())[first_stage:]:
next_module = compile_ir(module, metadata)
metadata_group[f"{src.name}.{ext}"] = fn_cache_manager.put(next_module, f"{src.name}.{ext}")
Expand Down Expand Up @@ -218,7 +216,7 @@ def __init__(self, so_path, metadata_path):
spec = importlib.util.spec_from_file_location("__triton_launcher", so_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
self.c_wrapper = getattr(mod, "launch")
self.run = getattr(mod, "launch")
# initialize metadata
self.metadata = json.loads(metadata_path.read_text())
self.metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in self.metadata['tensormaps_info']
Expand All @@ -243,7 +241,7 @@ def __init__(self, so_path, metadata_path):
def _init_handles(self):
if self.module is not None:
return
device = get_current_device()
device = driver.get_current_device()
# not enough shared memory to run the kernel
max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
if self.shared > max_shared:
Expand All @@ -253,7 +251,7 @@ def _init_handles(self):
self.name, self.kernel, self.shared, device)

def __getattribute__(self, name):
if name == 'c_wrapper':
if name == 'run':
self._init_handles()
return super().__getattribute__(name)

Expand All @@ -263,9 +261,10 @@ def __getitem__(self, grid):
def runner(*args, stream=None):
args_expand = driver.assemble_tensormap_to_arg(self.tensormaps_info, args)
if stream is None:
stream = get_cuda_stream()
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.cluster_dims[0],
self.cluster_dims[1], self.cluster_dims[2], self.shared, stream, self.function,
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)
device = driver.get_current_device()
stream = driver.get_current_stream(device)
self.run(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.cluster_dims[0],
self.cluster_dims[1], self.cluster_dims[2], self.shared, stream, self.function,
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)

return runner
18 changes: 18 additions & 0 deletions python/triton/runtime/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import tempfile
from pathlib import Path
import functools

from ..common.build import _build
from .cache import get_cache_manager
Expand Down Expand Up @@ -101,6 +102,23 @@ def __init__(self):
self.utils = CudaUtils()
self.backend = self.CUDA
self.binary_ext = "cubin"
# TODO: support other frameworks than torch
import torch
self.get_device_capability = torch.cuda.get_device_capability
try:
from torch._C import _cuda_getCurrentRawStream
self.get_current_stream = _cuda_getCurrentRawStream
except ImportError:
self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream
self.get_current_device = torch.cuda.current_device
self.set_current_device = torch.cuda.set_device

@functools.lru_cache()
def get_current_target(self):
    # Returns a (backend_name, capability) tuple identifying the active
    # device, with compute capability encoded as a single int,
    # e.g. (8, 9) -> 89.  Cached because the result does not change while
    # the process keeps using the same device.
    # NOTE(review): lru_cache on an instance method pins `self` for the
    # cache's lifetime; acceptable only if this driver object is a
    # long-lived singleton — confirm.
    device = self.get_current_device()
    capability = self.get_device_capability(device)
    capability = capability[0] * 10 + capability[1]
    return ("cuda", capability)

def assemble_tensormap_to_arg(self, tensormaps_info, args):
args_with_tma = list(args)
Expand Down
Loading