diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 155c2824e..2928cbc14 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -155,10 +155,12 @@ def visit_Return(self, node: ast.Return) -> bool: def visit_Assign(self, node: ast.Assign) -> bool: # There couldn't be an early return + # x = ... return False def visit_AugAssign(self, node: ast.AugAssign) -> bool: # There couldn't be an early return + # x += ... return False def visit_Module(self, node: ast.Module) -> bool: @@ -168,6 +170,13 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> bool: return self._visit_stmts(node.body) def visit_If(self, node: ast.If) -> bool: + # TODO: optimize the following case in which we actually don't have + # a return when static_cond is false: + # if dynamic_cond + # if static_cond + # func_with_return + # else + # func_without_return ret = self._visit_stmts(node.body) if node.orelse: ret = ret or self._visit_stmts(node.orelse) @@ -192,6 +201,9 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n self.begin_line = begin_line - 1 self.builder.set_loc(file_name, begin_line, 0) self.builder.options = options + # dict of functions provided by the backend. Below is the list of possible functions: + # Convert custom types not natively supported on HW. + # convert_custom_types(input_tensor, dtype, fp_downcast_rounding=None, _builder=None) self.builder.codegen_fns = codegen_fns self.builder.module_map = {} if module_map is None else module_map self.module = self.builder.create_module() if module is None else module @@ -474,7 +486,10 @@ def visit_AnnAssign(self, node): return self.visit_Assign(node) def visit_Assign(self, node): - # flagtree: First, do normal assignment processing + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("anno_CodeGenerator_visit_Assign") + _names = [] if isinstance(node, ast.AnnAssign): _names += [self.visit(node.target)] @@ -498,30 +513,10 @@ def visit_Assign(self, node): not isinstance(value, native_nontensor_types): value = language.semantic.to_tensor(value, self.builder) self.set_value(name, value) - - # flagtree: After normal processing, check if we need to add hint annotation - if hasattr(node, 'lineno') and hasattr(self, 'jit_fn'): - line_num = node.lineno - # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later - function_def = self.jit_fn.parse() - line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) - flagtree_hints = line_flagtree_hints.get(line_num) - - # Check if this is a tl.load call with dot_pad_only_k hint - if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and - isinstance(node.value, ast.Call) and - isinstance(node.value.func, ast.Attribute) and - isinstance(node.value.func.value, ast.Name) and - node.value.func.value.id == 'tl' and - node.value.func.attr == 'load'): - - # Add hint annotation to the loaded tensor(s) - for name, value in zip(names, values): - if _is_triton_value(value): - # print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}") - # Create hint annotation - hint_val = self.builder.get_unit_attr() - self.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val) + + #flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + 
flagtree_backend_specialization("ext_CodeGenerator_visit_Assign_hint_anno", self, node, names, values) def visit_AugAssign(self, node): name = node.target.id @@ -828,6 +823,8 @@ def visit_While(self, node): liveins, insert_block = sr ip, last_loc = self._get_insertion_point_and_loc() + # loop body (the after region) + # loop_block = self.builder.create_block() dummy = self.builder.create_block() self.builder.set_insertion_point_to_start(dummy) self.scf_stack.append(node) @@ -921,8 +918,11 @@ def visit_For(self, node): return num_stages = None loop_unroll_factor = None - bind_sub_block = None - if IteratorClass in [language.range, language.parallel]: + + #flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + bind_sub_block = flagtree_backend_specialization("init_bind_sub_block") + if IteratorClass in [language.range] + ([language.parallel] if flagtree_backend_specialization("is_visit_For_support_parallel") else []): iterator = IteratorClass(*iter_args, **iter_kwargs) # visit iterator arguments # note: only `range` iterator is supported now @@ -932,8 +932,9 @@ def visit_For(self, node): step = iterator.step num_stages = iterator.num_stages loop_unroll_factor = iterator.loop_unroll_factor - if (IteratorClass is language.parallel): - bind_sub_block = iterator.bind_sub_block + + #flagtree backend specialization + bind_sub_block = flagtree_backend_specialization("set_bind_sub_block_when_parallel", IteratorClass, iterator, bind_sub_block) elif IteratorClass is range: # visit iterator arguments # note: only `range` iterator is supported now @@ -943,20 +944,10 @@ def visit_For(self, node): step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) else: raise RuntimeError('Only `range` and `static_range` iterators are currently supported') - - # flagtree: After normal processing, check if we need to override bind_sub_block - if hasattr(node, 'lineno') and hasattr(self, 'jit_fn'): - line_num = node.lineno - # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later - function_def = self.jit_fn.parse() - line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) - flagtree_hints = line_flagtree_hints.get(line_num) - - # Check if this is a range/for loop with bind_sub_block hint - if flagtree_hints and 'bind_sub_block' in flagtree_hints: - bind_sub_block = True - # print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}") - + + #flagtree backend specialization + bind_sub_block = flagtree_backend_specialization("check_override_bind_sub_block", self, node, bind_sub_block) + # handle negative constant step (not supported by scf.for in MLIR) negative_step = False if _is_constexpr(step) and step.value < 0: @@ -1021,7 +1012,8 @@ def visit_For(self, node): if loop_unroll_factor is not None: for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor)) if (bind_sub_block is not None) and bind_sub_block: - for_op.set_attr("bind_sub_block", self.builder.get_bool_attr(bind_sub_block)) + #flagtree backend specialization + flagtree_backend_specialization("forop_setattr_for_bind_sub_block", self, for_op, bind_sub_block) self.scf_stack.append(node) self.builder.set_insertion_point_to_start(for_op.get_body(0)) @@ -1105,7 +1097,11 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): generator.visit(fn.parse()) except Exception as e: # Wrap the error in the callee with the location of the call. 
- raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from e + + #flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + raise CompilationError(self.jit_fn.src, self.cur_node, + repr(e) if flagtree_backend_specialization('need_repr_in_CodeGenerator_CompilationError') else None) from e callee_ret_type = generator.ret_type self.function_ret_types[fn_name] = callee_ret_type @@ -1149,7 +1145,11 @@ def visit_Call(self, node): # itself). But when calling a function, we raise as `from e` to # preserve the traceback of the original error, which may e.g. # be in core.py. - raise CompilationError(self.jit_fn.src, node, repr(e)) from e + + #flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + raise CompilationError(self.jit_fn.src, node, + repr(e) if flagtree_backend_specialization('need_repr_in_CodeGenerator_CompilationError') else None) from e if fn in self.builtin_namespace.values(): args = map(_unwrap_if_constexpr, args) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 6a8359d6f..a7bc26bd2 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -4,7 +4,6 @@ from .._C.libtriton import get_cache_invalidating_env_vars, ir from ..backends import backends from ..backends.compiler import GPUTarget, AttrsDescriptor -from ..backends.ascend.compiler import AscendAttrsDescriptor from .. import __version__ from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager @@ -12,7 +11,6 @@ from ..tools.disasm import get_sass # TODO: this shouldn't be here from .code_generator import ast_to_ttir -from .errors import MLIRCompilationError from pathlib import Path import re import functools @@ -87,8 +85,9 @@ def __init__(self, fn, signature, constants=None, attrs=None) -> None: for k in self.constants.keys(): if not isinstance(k, str): raise TypeError("Constants keys must be string") - if self.attrs is None: - self.attrs = AscendAttrsDescriptor() + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("ext_ASTSource_attrs", self) def hash(self): sorted_sig = [v for k, v in sorted(self.signature.items())] @@ -252,12 +251,11 @@ def compile(src, target=None, options=None): # cache hit! metadata = json.loads(Path(metadata_path).read_text()) return CompiledKernel(src, metadata_group, hash) - compile_speed_opt = os.getenv("TRITON_ASCEND_COMPILE_SPEED_OPT", 'false').lower() in ('true', '1') - if (compile_speed_opt): - ttir_path = f"{file_name}.ttir" - if (metadata_path is None) and (fn_cache_manager.has_file(ttir_path)): - # Already compile once but failed. 
So directly return - raise Exception("already failed once") + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("opt_ascend_compile_speed", file_name, metadata_path, fn_cache_manager) + # initialize metadata metadata = { "hash": hash, @@ -287,14 +285,10 @@ def compile(src, target=None, options=None): try: next_module = compile_ir(module, metadata) except Exception as e: - if (ext == "ttadapter"): - stage_name = "ConvertTritonIRToLinalgIR" - elif (ext == "npubin"): - stage_name = "ConvertLinalgRToBinary" - else: - stage_name = "MLIRCompile" - error_detail = e.stderr.decode('utf-8') if hasattr(e, 'stderr') and e.stderr else str(e) - raise MLIRCompilationError(stage_name, error_detail) + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("handle_compile_error", e, ext) + ir_filename = f"{file_name}.{ext}" if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None): print(f"\nOverriding kernel with file {full_name}") @@ -406,9 +400,12 @@ def _init_handles(self): # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary( self.name, self.kernel, self.metadata.shared, device) - - # This mechanism introduces heavy runtime overhead. - # Commenting __getattribute__ requires explicitly calling _init_handles() + def __getattribute__(self, name): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if name == 'run' and flagtree_backend_specialization("is_CompiledKernel_getattribute_need_init_handles"): + self._init_handles() + return super().__getattribute__(name) def launch_metadata(self, grid, stream, *args): if CompiledKernel.launch_enter_hook is None: @@ -431,8 +428,11 @@ def __getitem__(self, grid): self._init_handles() def runner(*args, stream=None): - if stream is None: - stream = self.metadata.stream + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("set_CompiledKernel_metadata_stream", self, stream) + if stream is None: device = driver.active.get_current_device() stream = driver.active.get_current_stream(device) diff --git a/python/triton/compiler/errors.py b/python/triton/compiler/errors.py index 5242258ad..39e6c4dfb 100644 --- a/python/triton/compiler/errors.py +++ b/python/triton/compiler/errors.py @@ -49,20 +49,3 @@ class CompileTimeAssertionFailure(CompilationError): class UnsupportedLanguageConstruct(CompilationError): pass - - -class MLIRCompilationError(TritonError): - def __init__(self, stage_name: Optional[str], message: Optional[str] = None): - self.stage_name = stage_name - self.message = f"\n" \ - f"{self.format_line_delim('[ERROR][Triton][BEG]')}" \ - f"[{self.stage_name}] encounters error:\n" \ - f"{self.filter_message(message)}" \ - f"{self.format_line_delim('[ERROR][Triton][END]')}" - def __str__(self): - return self.message - def filter_message(self, message): - # Content starting from "Stack dump without symbol names" means nothing to the users - return message.split("Stack dump without symbol names")[0] - def format_line_delim(self, keyword): - return f"///------------------{keyword}------------------\n" \ No newline at end of file diff --git a/python/triton/language/__init__.py 
b/python/triton/language/__init__.py index 6502a5348..773a52737 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -105,6 +105,29 @@ view, void, where, + gather, + get_element, + insert_slice, + extract_slice, + __add__, + __radd__, + __sub__, + __rsub__, + __mul__, + __rmul__, + __lshift__, + __rshift__, + parallel, + compile_hint, + make_tensor_descriptor, + load_tensor_descriptor, + store_tensor_descriptor, + multibuffer, + sync_block_all, + sync_block_set, + sync_block_wait, + dtype_to_ir, + sort ) from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor, ceil) @@ -152,6 +175,7 @@ "cdiv", "ceil", "clamp", + "compile_hint", "const", "constexpr", "cos", @@ -164,11 +188,13 @@ "dot", "dot_scaled", "dtype", + "dtype_to_ir", "erf", "exp", "exp2", "expand_dims", "extra", + "extract_slice", "fdiv", "flip", "float16", @@ -183,8 +209,11 @@ "fma", "full", "function_type", + "gather", + "get_element", "histogram", "inline_asm_elementwise", + "insert_slice", "interleave", "int1", "int16", @@ -194,9 +223,11 @@ "ir", "join", "load", + "load_tensor_descriptor", "log", "log2", "make_block_ptr", + "make_tensor_descriptor", "math", "max", "max_constancy", @@ -204,9 +235,11 @@ "maximum", "min", "minimum", + "multibuffer", "multiple_of", "num_programs", "pair_uniform_to_normal", + "parallel", "permute", "philox", "philox_impl", @@ -236,8 +269,12 @@ "static_print", "static_range", "store", + "store_tensor_descriptor", "sum", "swizzle2d", + "sync_block_all", + "sync_block_set", + "sync_block_wait", "tensor", "trans", "triton", @@ -253,6 +290,14 @@ "xor_sum", "zeros", "zeros_like", + "__add__", + "__radd__", + "__sub__", + "__rsub__", + "__mul__", + "__rmul__", + "__lshift__", + "__rshift__" ] diff --git a/python/triton/language/_utils.py b/python/triton/language/_utils.py index d0ca8c734..b89037db2 100644 --- a/python/triton/language/_utils.py +++ b/python/triton/language/_utils.py @@ -1,39 +1,23 @@ -from __future__ import annotations +from typing import List -from typing import List, TYPE_CHECKING, Any, Union, Dict +TRITON_MAX_TENSOR_NUMEL = 1048576 -if TYPE_CHECKING: - from .language import core - IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type] - ObjPath = tuple[int, ...] 
-TRITON_MAX_TENSOR_NUMEL = 1048576 +def is_power_of_two(x): + return (x & (x - 1)) == 0 + def validate_block_shape(shape: List[int]): numel = 1 for i, d in enumerate(shape): if not isinstance(d, int): raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]") + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if flagtree_backend_specialization('is_block_shape_check_power_of_two') and not is_power_of_two(d): + raise ValueError(f"Shape element {i} must be a power of 2") numel *= d if numel > TRITON_MAX_TENSOR_NUMEL: raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") return numel - - -BITWIDTH_DICT: Dict[str, int] = { - **{f"u{n}": n - for n in (1, 8, 16, 32, 64)}, - **{f"i{n}": n - for n in (1, 8, 16, 32, 64)}, - **{f"fp{n}": n - for n in (16, 32, 64)}, - **{f"fp8{suffix}": 8 - for suffix in ("e4nv", "e4b15", "e4b8", "e5", "e5b16")}, - "bf16": 16, - "void": 0, -} - - -def get_primitive_bitwidth(dtype: str) -> int: - return BITWIDTH_DICT[dtype] diff --git a/python/triton/language/core.py b/python/triton/language/core.py index e2c57b388..95dfee352 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1292,6 +1292,11 @@ def trans(input: tensor, *dims, _builder=None): """ if not dims: dims = (1, 0) + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + dims = flagtree_backend_specialization('ext_trans_unwrap_iterable', dims) + return semantic.permute(input, dims, _builder) @@ -1482,7 +1487,7 @@ def expand_dims(input, axis, _builder=None): @_tensor_member_fn @builtin -def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None): +def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, overflow_mode: Optional[str] = None, _builder=None): """ Casts a tensor to the given :code:`dtype`. @@ -1497,13 +1502,25 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas :param bitcast: If true, the tensor is bitcasted to the given :code:`dtype`, instead of being numerically casted. :type bitcast: bool, optional + :param overflow_mode: When overflow_mode is not set or is "trunc", + truncation (cut-off) will be used to handle overflow. When + overflow_mode is "saturate", the maximum value of the data type + will be used to handle overflow. + :type overflow_mode: str, optional """ + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + overflow_modes = flagtree_backend_specialization('ext_cast_set_overflow_modes') + input = semantic.to_tensor(input, _builder) if isinstance(bitcast, constexpr): bitcast = bitcast.value if bitcast: return semantic.bitcast(input, dtype, _builder) - return semantic.cast(input, dtype, _builder, fp_downcast_rounding) + # flagtree backend specialization + ret = semantic.cast(input, dtype, _builder, fp_downcast_rounding) + flagtree_backend_specialization('ext_cast_check_overflow_mode', overflow_mode, overflow_modes, ret, _builder) + return ret # ----------------------- @@ -1537,11 +1554,18 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i specified (i.e. at least one must be :code:`None`). 
""" assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified" + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_dot_deprecated_param_allow_tf32', allow_tf32) + if input_precision is None: supports_tf32 = _builder and "tf32" in _builder.options.allowed_dot_input_precisions default_precision = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee" input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision) - + else: + # flagtree backend specialization + flagtree_backend_specialization('check_dot_invalid_input_precision', input_precision) input_precision = _constexpr_to_value(input_precision) out_dtype = _constexpr_to_value(out_dtype) max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc) @@ -1659,6 +1683,35 @@ def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None): return semantic.descriptor_store(desc_pointer, value, offsets, _builder) +@builtin +def load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union[constexpr, tensor]], + _builder=None) -> tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_load_tensor_descriptor', desc, offsets, _builder) + + +@builtin +def store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union[constexpr, tensor]], value: tensor, + _builder=None) -> tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_store_tensor_descriptor', desc, offsets, value, _builder) + + +@builtin +def make_tensor_descriptor( + base: tensor, + shape: List[tensor], + strides: List[tensor], + block_shape: List[constexpr], + _builder=None, +) -> tensor_descriptor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_make_tensor_descriptor', base, shape, strides, block_shape, _builder) + + @_tensor_member_fn @builtin def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None): @@ -1737,6 +1790,57 @@ def advance(base, offsets, _builder=None): return semantic.advance(base, offsets, _builder) +@_tensor_member_fn +@builtin +def insert_slice(ful, sub, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_insert_slice', ful, sub, offsets, sizes, strides, _builder, _generator) + + +@_tensor_member_fn +@builtin +def extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_extract_slice', ful, offsets, sizes, strides, _builder, _generator) + + +@_tensor_member_fn +@builtin +def get_element(src, indice, _builder=None, _generator=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_get_element', src, indice, _builder, _generator) + + +@builtin +def multibuffer(src: tensor, size, _builder=None): + # flagtree backend 
specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_core_multibuffer', src, size, _builder) + + +@builtin +def sync_block_all(mode, event_id, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_core_sync_block_all', mode, event_id, _builder) + + +@builtin +def sync_block_set(sender, receiver, event_id, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_core_sync_block_set', sender, receiver, event_id, _builder) + + +@builtin +def sync_block_wait(sender, receiver, event_id, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_core_sync_block_wait', sender, receiver, event_id, _builder) + # ----------------------- # Atomic Memory Operations # ----------------------- @@ -1989,6 +2093,61 @@ def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=No return semantic.clamp(x, min, max, propagate_nan, _builder) +@builtin +def __add__(self, other, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_add', self, other, _builder) + + +@builtin +def __radd__(self, other, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_radd', self, other, _builder) + + +@builtin +def __sub__(self, other, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_sub', self, other, _builder) + + +@builtin +def __rsub__(self, other, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_rsub', self, other, _builder) + + +@builtin +def __mul__(self, other, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_mul', self, other, _builder) + + +@builtin +def __rmul__(self, other, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_rmul', self, other, _builder) + + +@builtin +def __lshift__(self, other, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_lshift', self, other, _builder) + + +@builtin +def __rshift__(self, other, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_rshift', self, other, _builder) + # ----------------------- # Reductions # ----------------------- @@ -2108,6 +2267,12 @@ def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None return rvalue, rindices +@builtin +def sort(ptr, dim=-1, descending=False, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return 
flagtree_backend_specialization('ext_core_sort', ptr, dim, descending, _builder) + # ----------------------- # Scans # ----------------------- @@ -2184,6 +2349,13 @@ def histogram(input, num_bins, _builder=None, _generator=None): return semantic.histogram(input, num_bins, _builder) +@_tensor_member_fn +@builtin +def gather(src, index, axis, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_gather', src, index, axis, _builder) + # ----------------------- # Compiler Hint Ops # ----------------------- @@ -2256,6 +2428,12 @@ def assume(cond, _builder=None): return semantic.assume(semantic.to_tensor(cond, _builder), _builder) +@builtin +def compile_hint(ptr, hint_name, hint_val=None, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_core_compile_hint', ptr, hint_name, hint_val, _builder) + # ----------------------- # Debugging functions # ----------------------- @@ -2588,6 +2766,21 @@ def __next__(self): raise RuntimeError("tl.range can only be used in @triton.jit'd functions") +# flagtree backend specialization +class parallel(range): + """ + Iterator that counts upward forever, with parallel execution semantics. + + This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it allows the user to pass extra attributes to the compiler. + :param bind_sub_block: Tells the compiler whether multiple vector cores participate in the loop. + This is used in the mixed cube-vector kernel on 910B. The number of vector cores is determined by the number of + iterations in this loop. Currently on 910B, at most 2 vector cores can be used. + """ + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, bind_sub_block: bool = False): + super().__init__(arg1, arg2, step, num_stages, loop_unroll_factor) + self.bind_sub_block = bind_sub_block + # ----------------------- # Extern functions # ----------------------- @@ -2692,3 +2885,9 @@ def binary_op_type_legalization(lhs, rhs, builder): def extern(fn): """A decorator for external functions.""" return builtin(fn) + + +def dtype_to_ir(self, builder: ir.builder) -> ir.type: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_dtype_to_ir', self, builder) diff --git a/python/triton/language/core_ext.py b/python/triton/language/core_ext.py index f44af20c2..f3f4154eb 100644 --- a/python/triton/language/core_ext.py +++ b/python/triton/language/core_ext.py @@ -20,7 +20,7 @@ mul, ) from typing import Optional -from . import semantic_ext as semantic +from . 
import semantic from .tensor_descriptor import tensor_descriptor, tensor_descriptor_base @@ -128,6 +128,7 @@ def dot( ) +# FIXME: non-exist in core.py @_tensor_member_fn @builtin def gather(src, index, axis, _builder=None): @@ -143,6 +144,7 @@ def gather(src, index, axis, _builder=None): return semantic.gather(src, index, axis, _builder) +# FIXME: non-exist in core.py @_tensor_member_fn @builtin def insert_slice(ful, sub, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: @@ -170,6 +172,7 @@ def insert_slice(ful, sub, offsets, sizes, strides, _builder=None, _generator=No return out +# FIXME: non-exist in core.py @_tensor_member_fn @builtin def extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: @@ -193,6 +196,7 @@ def extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) sub = semantic.extract_slice(ful, new_offsets, sizes, strides, _builder) return sub +# FIXME: non-exist in core.py @_tensor_member_fn @builtin def get_element(src, indice, _builder=None, _generator=None): @@ -213,31 +217,38 @@ def get_element(src, indice, _builder=None, _generator=None): ] return semantic.get_element(src, new_indice, _builder) +# FIXME: non-exist in core.py @builtin def __add__(self, other, _builder=None): return add(self, other, sanitize_overflow=False, _builder=_builder) +# FIXME: non-exist in core.py @builtin def __radd__(self, other, _builder=None): return add(other, self, sanitize_overflow=False, _builder=_builder) +# FIXME: non-exist in core.py @builtin def __sub__(self, other, _builder=None): return sub(self, other, sanitize_overflow=False, _builder=_builder) +# FIXME: non-exist in core.py @builtin def __rsub__(self, other, _builder=None): return sub(other, self, sanitize_overflow=False, _builder=_builder) +# FIXME: non-exist in core.py @builtin def __mul__(self, other, _builder=None): return mul(self, other, sanitize_overflow=False, _builder=_builder) +# FIXME: non-exist in core.py @builtin def __rmul__(self, other, _builder=None): return mul(other, self, sanitize_overflow=False, _builder=_builder) +# FIXME: non-exist in core.py @builtin def __lshift__(self, other, _builder=None): if self.type.scalar.is_floating(): @@ -247,6 +258,7 @@ def __lshift__(self, other, _builder=None): return semantic.shl(self, other, _builder) +# FIXME: non-exist in core.py @builtin def __rshift__(self, other, _builder=None): if self.type.scalar.is_floating(): @@ -259,6 +271,7 @@ def __rshift__(self, other, _builder=None): return semantic.lshr(self, other, _builder) +# FIXME: non-exist in core.py class parallel(range): """ Iterator that counts upward forever, with parallel execution semantics. 
@@ -274,6 +287,7 @@ def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_fact self.bind_sub_block = bind_sub_block +# FIXME: non-exist in core.py @builtin def compile_hint(ptr, hint_name, hint_val=None, _builder=None): def _unwrap(val): @@ -289,6 +303,7 @@ def _unwrap(val): semantic.compile_hint(ptr, hint_name, hint_val, _builder) +# FIXME: non-exist in core.py @builtin def sort(ptr, dim=-1, descending=False, _builder=None): """ @@ -319,7 +334,8 @@ def sort(ptr, dim=-1, descending=False, _builder=None): semantic.compile_hint(ret, "overflow_mode", constexpr("saturate"), _builder) return ret - + +# FIXME: non-exist in core.py @builtin def multibuffer(src: tensor, size, _builder=None): """ @@ -332,6 +348,7 @@ def multibuffer(src: tensor, size, _builder=None): semantic.compile_hint(src, "multi_buffer", buffer_size, _builder) +# FIXME: non-exist in core.py @builtin def sync_block_all(mode, event_id, _builder=None): mode = _constexpr_to_value(mode) @@ -342,6 +359,7 @@ def sync_block_all(mode, event_id, _builder=None): semantic.custom_op(_builder, "sync_block_all", mode=mode, event_id=event_id) +# FIXME: non-exist in core.py @builtin def sync_block_set(sender, receiver, event_id, _builder=None): sender = _constexpr_to_value(sender) @@ -355,6 +373,7 @@ def sync_block_set(sender, receiver, event_id, _builder=None): semantic.custom_op(_builder, "sync_block_set", sender=sender, event_id=event_id) +# FIXME: non-exist in core.py @builtin def sync_block_wait(sender, receiver, event_id, _builder=None): sender = _constexpr_to_value(sender) @@ -368,6 +387,7 @@ def sync_block_wait(sender, receiver, event_id, _builder=None): semantic.custom_op(_builder, "sync_block_wait", sender=sender, event_id=event_id) +# FIXME: non-exist in core.py @builtin def load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union[constexpr, tensor]], _builder=None) -> tensor: @@ -375,6 +395,7 @@ def load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union return desc.load(offsets, _builder=_builder) +# FIXME: non-exist in core.py @builtin def store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union[constexpr, tensor]], value: tensor, _builder=None) -> tensor: @@ -382,6 +403,7 @@ def store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Unio return desc.store(offsets, value, _builder=_builder) +# FIXME: non-exist in core.py @builtin def make_tensor_descriptor( base: tensor, @@ -440,6 +462,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): return semantic.make_tensor_descriptor(base, shape, strides, block_shape, _builder) +# FIXME: non-exist in core.py def dtype_to_ir(self, builder: ir.builder) -> ir.type: if self.name.startswith("fp8"): raise ValueError(f'unexpected type fp8.') @@ -475,4 +498,4 @@ def dtype_to_ir(self, builder: ir.builder) -> ir.type: elif self.name == 'fp64': return builder.get_double_ty() raise ValueError(f'fail to convert {self} to ir type') - + diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 8e9f87b5e..a7d0e601a 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -305,6 +305,12 @@ def floordiv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Numbe input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar + + # flagtree backend specialization + from triton.runtime.driver import 
flagtree_backend_specialization + flagtree_backend_specialization('check_was_bool_to_int8_dtype', input) + flagtree_backend_specialization('check_was_bool_to_int8_dtype', other) + if input_scalar_ty.is_int() and other_scalar_ty.is_int(): ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty) input = cast(input, ret_ty, builder) @@ -332,6 +338,12 @@ def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, bu scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar # float % float + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_was_bool_to_int8_dtype', input) + flagtree_backend_specialization('check_was_bool_to_int8_dtype', other) + if scalar_ty.is_floating(): # input - input.div(other, rounding_mode="floor") * other floor = math.floor(fdiv(input, other, False, builder), _builder=builder) @@ -358,6 +370,9 @@ def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, bu def minimum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): x, y = binary_op_type_checking_impl(x, y, builder) dtype = x.dtype + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_unexpected_dtype_bool', dtype) if dtype.is_floating(): if propagate_nan == tl.PropagateNan.ALL: return tl.tensor(builder.create_minimumf(x.handle, y.handle), x.type) @@ -376,6 +391,9 @@ def minimum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: def maximum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): x, y = binary_op_type_checking_impl(x, y, builder) dtype = x.dtype + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_unexpected_dtype_bool', dtype) if dtype.is_floating(): if propagate_nan == tl.PropagateNan.ALL: return tl.tensor(builder.create_maximumf(x.handle, y.handle), x.type) @@ -424,37 +442,66 @@ def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor, def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_unexpected_dtype_float', input) + input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_and(input.handle, other.handle), input.type) def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_unexpected_dtype_float', input) + input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_or(input.handle, other.handle), input.type) def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_unexpected_dtype_float', input) + input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import 
flagtree_backend_specialization + if hasattr(input, 'was_bool_to_int8'): + input = flagtree_backend_specialization('check_was_bool_to_int8_dtype_and_cast', input, builder) if not input.type.is_int1(): input = bitcast(input, tl.dtype("int1"), builder) + if hasattr(other, 'was_bool_to_int8'): + other = flagtree_backend_specialization('check_was_bool_to_int8_dtype_and_cast', other, builder) if not other.type.is_int1(): other = bitcast(other, tl.dtype("int1"), builder) return and_(input, other, builder) def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if hasattr(input, 'was_bool_to_int8'): + input = flagtree_backend_specialization('check_was_bool_to_int8_dtype_and_cast', input, builder) if not input.type.is_int1(): input = bitcast(input, tl.dtype("int1"), builder) + if hasattr(other, 'was_bool_to_int8'): + other = flagtree_backend_specialization('check_was_bool_to_int8_dtype_and_cast', other, builder) if not other.type.is_int1(): other = bitcast(other, tl.dtype("int1"), builder) return or_(input, other, builder) def not_(input: tl.tensor, builder: ir.builder): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if hasattr(input, 'was_bool_to_int8'): + input = flagtree_backend_specialization('check_was_bool_to_int8_dtype_and_cast', input, builder) + if not input.type.is_int1(): input = bitcast(input, tl.dtype("int1"), builder) return invert(input, builder) @@ -486,6 +533,11 @@ def plus(input: tl.tensor) -> tl.tensor: def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor: input_sca_ty = input.type.scalar + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_was_bool_to_int8_dtype', input) + if input_sca_ty.is_ptr(): raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") _0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) @@ -493,7 +545,13 @@ def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor: def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if hasattr(input, 'was_bool_to_int8'): + input = flagtree_backend_specialization('check_was_bool_to_int8_dtype_and_cast', input, builder) + input_sca_ty = input.type.scalar + flagtree_backend_specialization('check_unexpected_dtype_float', input) if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty) @@ -609,8 +667,11 @@ def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: if end <= start: raise ValueError("arange's end argument must be greater than the start argument") range = end - start - if (range & (range - 1)) != 0: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if flagtree_backend_specialization('is_arange_check_power_of_two') and (range & (range - 1)) != 0: raise ValueError("arange's range must be a power of 2") + flagtree_backend_specialization('check_arange_less_than_max_numel', range) shape = [range] ret_ty = tl.block_type(tl.int32, shape) return tl.tensor(builder.create_make_range(start, end), ret_ty) @@ -842,6 +903,10 
@@ def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, src_sca_ty = src_ty.scalar dst_sca_ty = dst_ty.scalar + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if flagtree_backend_specialization('is_cast_src_dst_scalar_type_equal', src_sca_ty, dst_sca_ty): + return input # For fp downcasting default rounding mode should be RTNE, for all other conversions it should # not be set @@ -856,6 +921,10 @@ def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_unsupported_fp8_fp64', src_sca_ty, dst_sca_ty) + if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): assert builder.codegen_fns.get( "convert_custom_types") is not None, "target doesn't provide conversion for this type." @@ -1076,6 +1145,10 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_ "pointers or loading a scalar. Because the compiler does not know the boundary; please " "use block pointers (defined by `make_block_ptr`) instead") + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if other is None: + other = flagtree_backend_specialization('set_load_legacy_other_input', builder) # For a pointer of scalar, check the type of `mask` and `other` if not ptr.type.is_block(): if mask and mask.type.is_block(): @@ -1120,8 +1193,10 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_ ret = tl.tensor( builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, is_volatile), dst_ty) - if is_bool: + if is_bool and flagtree_backend_specialization('cast_back_when_load_legacy_ptr_is_bool'): ret = cast(ret, tl.int1, builder) + + flagtree_backend_specialization('set_attr_was_bool_to_int8', ret, is_bool) return ret @@ -1289,8 +1364,13 @@ def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: sem = _str_to_sem(sem) scope = _str_to_scope(scope) element_ty = ptr.type.scalar.element_ty - if element_ty.primitive_bitwidth not in [16, 32, 64]: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if element_ty.primitive_bitwidth not in [16, 32, 64] and flagtree_backend_specialization('is_atomic_cas_need_element_bitwidth_check'): raise ValueError("atomic_cas only supports elements with width {16, 32, 64}") + + flagtree_backend_specialization('ext_atomic_cas_element_typechecking', element_ty) + return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type) @@ -1301,10 +1381,16 @@ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, if ptr.type.is_const() or ptr.type.element_ty.is_const(): raise ValueError("Cannot store to a constant pointer") element_ty = ptr.type.scalar.element_ty - if element_ty is tl.float16 and op != 'add': + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if element_ty is tl.float16 and op != 'add' and flagtree_backend_specialization('is_atomic_need_original_check'): raise ValueError("atomic_" + op + " does not support fp16") - if element_ty in [tl.int1, tl.int8, tl.int16, 
tl.bfloat16]: + if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16] and flagtree_backend_specialization('is_atomic_need_original_check'): raise ValueError("atomic_" + op + " does not support " + str(element_ty)) + + # flagtree backend specialization + flagtree_backend_specialization("ext_atomic_element_typechecking", element_ty, op) + if ptr.type.is_block(): if mask is not None: mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) @@ -1334,6 +1420,12 @@ def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: else: return tl.tensor( builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if flagtree_backend_specialization('is_atomic_max_no_bitcast'): + return flagtree_backend_specialization('atomic_max_returning_tensor', ir, ptr, val, mask, sem, scope, builder) + # for float # return atomic_smax(i_ptr, i_val) if val >= 0 # return atomic_umin(i_ptr, i_val) if val < 0 @@ -1373,6 +1465,12 @@ def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: else: return tl.tensor( builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if flagtree_backend_specialization('is_atomic_min_no_bitcast'): + return flagtree_backend_specialization('atomic_min_returning_tensor', ir, ptr, val, mask, sem, scope, builder) + # for float # return atomic_smin(i_ptr, i_val) if val >= 0 # return atomic_umax(i_ptr, i_val) if val < 0 @@ -1463,10 +1561,12 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona # All combinations of supported fp8 x fp8 are permitted pass else: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, - tl.float32), f"Unsupported lhs dtype {lhs.dtype}" + tl.float32) + flagtree_backend_specialization('ext_dot_lhs_supported_type'), f"Unsupported lhs dtype {lhs.dtype}" assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, - tl.float32), f"Unsupported rhs dtype {rhs.dtype}" + tl.float32) + flagtree_backend_specialization('ext_dot_rhs_supported_type'), f"Unsupported rhs dtype {rhs.dtype}" assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. 
Got {lhs.dtype} and {rhs.dtype}" if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): @@ -1513,6 +1613,10 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona acc_handle = acc.handle assert acc.type == ret_ty + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('dot_check_hf32_input_precision', input_precision, ir, lhs, rhs, ret_scalar_ty) + # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 if max_num_imprecise_acc is None: if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): @@ -1520,9 +1624,9 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona else: max_num_imprecise_acc = 0 else: - if lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc > K: + if flagtree_backend_specialization('is_dot_check_max_num_imprecise_acc') and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc > K: raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})") - + max_num_imprecise_acc = flagtree_backend_specialization('reset_dot_max_num_imprecise_acc', max_num_imprecise_acc) return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), ret_ty) @@ -1538,6 +1642,14 @@ def _str_to_fp_type(float_format: Optional[str]): return ir.F8F6F4TY.E3M2 if float_format == 'e2m1': return ir.F8F6F4TY.E2M1 + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if float_format == 'bf16' and flagtree_backend_specialization('is_float_format_support_bf16'): + return ir.F8F6F4TY.BF16 + if float_format == 'fp16' and flagtree_backend_specialization('is_float_format_support_fp16'): + return ir.F8F6F4TY.FP16 + raise ValueError(f"Invalid float format: {float_format}.") @@ -1545,22 +1657,42 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor, rhs_format, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: assert lhs.type.is_block() and rhs.type.is_block() #TODO: validate types. 
+ + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_dot_scaled_validate_lhs_dtype', lhs) + flagtree_backend_specialization('ext_dot_scaled_validate_rhs_dtype', rhs) + flagtree_backend_specialization('ext_dot_scaled_check_same_dtype', lhs, rhs) + lhs_rank = len(lhs.shape) rhs_rank = len(rhs.shape) assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" lhs_format_enum = _str_to_fp_type(lhs_format) rhs_format_enum = _str_to_fp_type(rhs_format) - assert lhs_format in ("e2m1", "e4m3", "e5m2"), f"NYI: lhs_format {lhs_format}" - assert rhs_format in ("e4m3", "e5m2"), f"NYI: rhs_format {rhs_format}" + assert lhs_format in ("e2m1", "e4m3", "e5m2") or not flagtree_backend_specialization('is_dot_scaled_need_original_check'), f"NYI: lhs_format {lhs_format}" + assert rhs_format in ("e4m3", "e5m2") or not flagtree_backend_specialization('is_dot_scaled_need_original_check'), f"NYI: rhs_format {rhs_format}" + flagtree_backend_specialization('ext_dot_scaled_check_lhs_rhs_format', lhs_format, rhs_format) + rhs_scale_is_none = isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None - assert rhs_scale_is_none, "NYI: rhs_scale not supported" + rhs_scale_is_none = flagtree_backend_specialization('dot_scaled_recheck_rhs_scale_is_none', rhs_scale, rhs_scale_is_none) + lhs_scale_is_none = flagtree_backend_specialization('dot_scaled_check_lhs_scale_is_none', lhs_scale) + assert rhs_scale_is_none or flagtree_backend_specialization('is_dot_scaled_support_rhs_scale'), "NYI: rhs_scale not supported" + + flagtree_backend_specialization('check_dot_scaled_lhs_scale_dtype', lhs_scale) + flagtree_backend_specialization('check_dot_scaled_rhs_scale_dtype', rhs_scale, rhs_scale_is_none) + lhs = flagtree_backend_specialization('dot_scaled_lhs_bitcast_to_fp_type', lhs, lhs_format, builder) + rhs = flagtree_backend_specialization('dot_scaled_rhs_bitcast_to_fp_type', rhs, rhs_format, builder) + + flagtree_backend_specialization('check_dot_scaled_dimension', lhs, rhs) M = lhs.type.shape[-2] K, N = rhs.type.shape[-2:] PACKED = 2 if lhs_format == "e2m1" else 1 assert K == PACKED * lhs.type.shape[ - -1], f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})" - assert K >= 64, f"scaled_dot NYI for K < 64. Got {K=}" + -1] or not flagtree_backend_specialization('is_dot_scaled_need_original_check'), \ + f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + flagtree_backend_specialization('check_dot_scaled_pack_size', PACKED, K, lhs_format, lhs, rhs) + assert K >= 64 or not flagtree_backend_specialization('is_dot_scaled_need_original_check'), f"scaled_dot NYI for K < 64. 
Got {K=}" B = lhs.type.shape[0] if lhs_rank == 3 else None ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N]) @@ -1571,6 +1703,7 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor, acc_handle = acc.handle assert acc.type == ret_ty rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle + lhs_scale_handle = flagtree_backend_specialization('set_dot_scaled_lhs_scale_handle', lhs_scale, lhs_scale_is_none) return tl.tensor( builder.create_dot_scaled(lhs.handle, lhs_scale.handle, lhs_format_enum, rhs.handle, rhs_scale_handle, rhs_format_enum, acc_handle), ret_ty) @@ -1794,3 +1927,69 @@ def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: # Advanced block pointer type is the same as before return tl.tensor(builder.create_advance(base.handle, offsets), base.type) + + +def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_semantic_gather', src, index, axis, builder) + + +def insert_slice(ful: tl.tensor, sub: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_semantic_insert_slice', ful, sub, offsets, sizes, strides, builder) + + +def extract_slice(ful: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_semantic_extract_slice', ful, offsets, sizes, strides, builder) + + +def get_element(src: tl.tensor, indice: List[tl.tensor], builder: ir.builder): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_semantic_get_element', src, indice, builder) + + +def compile_hint(ptr: tl.tensor, hint_name: str, hint_val, builder: ir.builder): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("ext_semantic_compile_hint", ptr, hint_name, hint_val, builder) + + +def custom_op(builder: ir.builder, op_name: str, **kwargs): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_semantic_custom_op', builder, op_name, **kwargs) + + +def sort(ptr: tl.tensor, dim: int, descending, builder: ir.builder): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_semantic_sort', ptr, dim, descending, builder) + + +def scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_semantic_scalar_constant', value, dtype, builder) + + +def make_scalar(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_semantic_make_scalar', value, dtype, builder) + + +def make_tensor_descriptor( + base: 
tl.tensor, + shape: List[tl.tensor], + strides: List[tl.tensor], + block_shape: List[tl.constexpr], + builder: ir.builder +) -> tensor_descriptor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_semantic_make_tensor_descriptor', base, shape, strides, block_shape, builder) diff --git a/python/triton/language/semantic_ext.py b/python/triton/language/semantic_ext.py index 8ef3e2ec9..e27dbb21e 100644 --- a/python/triton/language/semantic_ext.py +++ b/python/triton/language/semantic_ext.py @@ -44,7 +44,7 @@ def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) if src_ty == dst_ty: return input - + src_sca_ty = src_ty.scalar dst_sca_ty = dst_ty.scalar if src_sca_ty == dst_sca_ty: @@ -298,12 +298,12 @@ def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) - +# FIXME: non-exist in semantic.py def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: assert index.dtype.is_int(), "index must be an integer tensor" if not src.dtype.is_floating(): raise ValueError(f"Expected dtype fp16/fp32/bf16, but got {src.dtype}") - + rank = len(src.type.shape) assert len(index.type.shape) == rank, "source and index tensors must have the same rank" @@ -319,6 +319,7 @@ def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> gather = builder.create_gather(src.handle, index.handle, axis) return wrap_tensor(gather, src.type.scalar, index.type.shape) +# FIXME: non-exist in semantic.py def insert_slice(ful: tl.tensor, sub: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tl.tensor: assert(len(ful.shape) == len(offsets)) assert(len(ful.shape) == len(sizes)) @@ -484,6 +485,7 @@ def maximum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: else: raise TypeError(f"Unexpected dtype {dtype}") +# FIXME: non-exist in semantic.py def extract_slice(ful: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tl.tensor: assert(len(ful.shape) == len(offsets)) assert(len(ful.shape) == len(sizes)) @@ -495,6 +497,7 @@ def extract_slice(ful: tl.tensor, offsets: List[tl.tensor], sizes: List[int], st out = builder.create_extract_slice(ful.handle, new_offsets, sizes, strides) return tl.tensor(out, ret_type) +# FIXME: non-exist in semantic.py def get_element(src: tl.tensor, indice: List[tl.tensor], builder: ir.builder): if len(src.shape) != len(indice): raise ValueError("Indice's rank must be equal to src tensor's rank") @@ -583,6 +586,7 @@ def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) +# FIXME: non-exist in semantic.py def compile_hint(ptr: tl.tensor, hint_name: str, hint_val, builder: ir.builder): if not hint_val: hint_val = builder.get_unit_attr() @@ -600,19 +604,21 @@ def compile_hint(ptr: tl.tensor, hint_name: str, hint_val, builder: ir.builder): builder.create_annotation(ptr.handle, hint_name, hint_val) +# FIXME: non-exist in semantic.py def custom_op(builder: ir.builder, op_name: str, **kwargs): if op_name == "sync_block_all": return 
builder.create_custom_op_for_inter_core_sync(op_name, kwargs["mode"], kwargs["event_id"]) elif op_name == "sync_block_set": return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["sender"], kwargs["event_id"]) - + elif op_name == "sync_block_wait": return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["sender"], kwargs["event_id"]) - + raise ValueError(f"Unsupported custom op: {op_name}") +# FIXME: non-exist in semantic.py def sort(ptr: tl.tensor, dim: int, descending, builder: ir.builder): """ Triton sort 操作 @@ -692,7 +698,7 @@ def _str_to_fp_type(float_format: Optional[str]): return ir.F8F6F4TY.FP16 raise ValueError(f"Invalid float format: {float_format}.") - +# FIXME: non-exist in semantic.py def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder): triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16, "fp16": tl.float16}.get(float_format) if triton_ty is None: @@ -720,7 +726,7 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.te rhs_format: str = rhs_format.value lhs_format_enum = _str_to_fp_type(lhs_format) rhs_format_enum = _str_to_fp_type(rhs_format) - allowed_formats = {"bf16", "fp16"} # unsupported fp8/4 dtype: "e2m1", "e4m3", "e5m2" + allowed_formats = {"bf16", "fp16"} # unsupported fp8/4 dtype: "e2m1", "e4m3", "e5m2" assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}" assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}" rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None) @@ -759,6 +765,7 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.te +# FIXME: non-exist in semantic.py def scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: if dtype is None: raise ValueError("dtype must be specified when value is not a tensor") @@ -770,6 +777,7 @@ def scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: return tl.tensor(value, dtype) +# FIXME: non-exist in semantic.py def make_scalar(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: if isinstance(value, tl.tensor): assert value.numel.value == 1, "only accepts size-1 tensor" @@ -777,6 +785,7 @@ def make_scalar(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: return scalar_constant(value, dtype, builder) +# FIXME: non-exist in semantic.py def make_tensor_descriptor( base: tl.tensor, shape: List[tl.tensor], @@ -805,7 +814,7 @@ def make_tensor_descriptor( strides[-1] = _unwrap_if_constexpr(strides[-1]) if strides[-1] != 1: raise ValueError(f"Tensor descriptor last dim must be 1 but got {strides[-1]}") - + shape = [make_scalar(x, tl.int32, builder) for x in shape] strides = [make_scalar(x, tl.int64, builder) for x in strides] diff --git a/python/triton/language/tensor_descriptor.py b/python/triton/language/tensor_descriptor.py index 077c889e1..83ed64a85 100644 --- a/python/triton/language/tensor_descriptor.py +++ b/python/triton/language/tensor_descriptor.py @@ -21,7 +21,7 @@ _str_to_eviction_policy, ) -from ._utils import validate_block_shape, get_primitive_bitwidth +from ._utils import validate_block_shape def _unwrap_if_constexpr(o): @@ -227,6 +227,9 @@ def __init__(self, name): name = _unwrap_if_constexpr(name) self.name = name assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_func_specialization + get_primitive_bitwidth = 
flagtree_backend_func_specialization("get_primitive_bitwidth") self.primitive_bitwidth = get_primitive_bitwidth(name) self.itemsize = self.primitive_bitwidth // 8 if name in dtype.SINT_TYPES: diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 67753e129..95061c610 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -4,7 +4,7 @@ import os import time import inspect -from typing import Dict, List +from typing import Dict from .jit import KernelInterface from .errors import OutOfResources @@ -28,6 +28,7 @@ def __init__( rep=None, use_cuda_graph=False, do_bench=None, + # flagtree backend specialization auto_profile_dir=None, ): """ @@ -36,9 +37,15 @@ def __init__( 'top_k': number of configs to bench 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. """ + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if not configs: self.configs = [ - Config({}) + flagtree_backend_specialization('get_spec_default_Autotuner_configs') + if flagtree_backend_specialization('has_spec_default_Autotuner_configs') + else Config({}, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, + reg_dec_producer=0, reg_inc_consumer=0) ] else: self.configs = configs @@ -100,7 +107,9 @@ def _post_hook(kwargs, exception): self.num_warmups = warmup self.num_reps = rep self.use_cuda_graph = use_cuda_graph - self.auto_profile_dir = auto_profile_dir + + # flagtree backend specialization + flagtree_backend_specialization('set_Autotuner_auto_profile_dir', self, auto_profile_dir) # If we got explicitly called via the old interface, raise a warning # and proceed with the old behavior. @@ -133,7 +142,7 @@ def _post_hook(kwargs, exception): self.do_bench = do_bench def _bench(self, *args, config, **meta): - from ..compiler.errors import CompileTimeAssertionFailure, MLIRCompilationError + from ..compiler.errors import CompileTimeAssertionFailure # check for conflicts, i.e. meta-parameters both provided # as kwargs and by the autotuner @@ -163,46 +172,15 @@ def kernel_call(): self.post_hook(full_nargs, exception=None) + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + try: return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) - except (OutOfResources, CompileTimeAssertionFailure, MLIRCompilationError) as e: + except (OutOfResources, CompileTimeAssertionFailure) + \ + flagtree_backend_specialization("ext_Autotuner_do_bench_MLIRCompilationError") as e: return [float("inf"), float("inf"), float("inf")] - def _profile(self, *args, config, **meta): - from triton.testing import do_bench_npu - - # check for conflicts, i.e. meta-parameters both provided - # as kwargs and by the autotuner - conflicts = meta.keys() & config.kwargs.keys() - if conflicts: - raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." 
- " Make sure that you don't re-define auto-tuned symbols.") - # augment meta-parameters with tunable ones - current = dict(meta, **config.all_kwargs()) - full_nargs = {**self.nargs, **current} - - def kernel_call(): - if config.pre_hook: - config.pre_hook(full_nargs) - self.pre_hook(full_nargs) - try: - self.fn.run( - *args, - **current, - ) - except Exception as e: - try: - self.post_hook(full_nargs, exception=e) - finally: - # Throw exception raised by `self.fn.run` - raise - - self.post_hook(full_nargs, exception=None) - - do_bench_npu( - kernel_call, prof_dir=self.auto_profile_dir, keep_res=True - ) - def run(self, *args, **kwargs): self.nargs = dict(zip(self.arg_names, args)) used_cached_result = True @@ -234,8 +212,10 @@ def run(self, *args, **kwargs): print(f"Triton autotuning for function {self.base_fn.__name__} finished after " f"{self.bench_time:.2f}s; best config selected: {self.best_config};") - if not used_cached_result and self.auto_profile_dir is not None: - self._profile(*args, config=self.best_config, **kwargs) + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_Autotuner_profile', self, used_cached_result, args, kwargs) + if config.pre_hook is not None: full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()} config.pre_hook(full_nargs) @@ -315,18 +295,14 @@ def __init__(self, kwargs, num_warps=None, num_stages=None, num_ctas=None, num_b self.maxnreg = maxnreg self.pre_hook = pre_hook - - # BiShengIR Options allowed for autotune - self.multibuffer = bishengir_options.get("multibuffer", None) # Compiler Default True - self.unit_flag = bishengir_options.get("unit_flag", None) # Compiler Default False - self.limit_auto_multi_buffer_only_for_local_buffer = bishengir_options.get("limit_auto_multi_buffer_only_for_local_buffer", None) # Compiler Default False - self.limit_auto_multi_buffer_of_local_buffer = bishengir_options.get("limit_auto_multi_buffer_of_local_buffer", None) # Compiler Default no-limit - self.set_workspace_multibuffer = bishengir_options.get("set_workspace_multibuffer", None) # Compiler Default 1 - self.enable_hivm_auto_cv_balance = bishengir_options.get("enable_hivm_auto_cv_balance", None) # Compiler Default True - self.tile_mix_vector_loop = bishengir_options.get("tile_mix_vector_loop", None) # Compiler Default 1 - self.tile_mix_cube_loop = bishengir_options.get("tile_mix_cube_loop", None) # Compiler Default 1 + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('set_Config_BiShengIR_options', self, bishengir_options) def all_kwargs(self): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return { **self.kwargs, **{ k: v @@ -339,17 +315,7 @@ def all_kwargs(self): ("reg_dec_producer", self.reg_dec_producer), ("reg_inc_consumer", self.reg_inc_consumer), ("maxnreg", self.maxnreg), - - ("multibuffer", self.multibuffer), - ("enable_hivm_auto_cv_balance", self.enable_hivm_auto_cv_balance), - ("unit_flag", self.unit_flag), - ("limit_auto_multi_buffer_only_for_local_buffer", \ - self.limit_auto_multi_buffer_only_for_local_buffer), - ("limit_auto_multi_buffer_of_local_buffer", self.limit_auto_multi_buffer_of_local_buffer), - ("set_workspace_multibuffer", self.set_workspace_multibuffer), - ("tile_mix_vector_loop", self.tile_mix_vector_loop), - ("tile_mix_cube_loop", self.tile_mix_cube_loop), - ) if v is not None + ) + 
flagtree_backend_specialization('ext_Config_all_kwargs', self) if v is not None } } @@ -366,15 +332,10 @@ def __str__(self): res.append(f"reg_inc_consumer: {self.reg_inc_consumer}") res.append(f"maxnreg: {self.maxnreg}") - res.append(f"multibuffer: {self.multibuffer}") - res.append(f"enable_hivm_auto_cv_balance: {self.enable_hivm_auto_cv_balance}") - res.append(f"unit_flag: {self.unit_flag}") - res.append(f"limit_auto_multi_buffer_only_for_local_buffer: \ - {self.limit_auto_multi_buffer_only_for_local_buffer}") - res.append(f"limit_auto_multi_buffer_of_local_buffer: {self.limit_auto_multi_buffer_of_local_buffer}") - res.append(f"set_workspace_multibuffer: {self.set_workspace_multibuffer}") - res.append(f"tile_mix_vector_loop: {self.tile_mix_vector_loop}") - res.append(f"tile_mix_cube_loop: {self.tile_mix_cube_loop}") + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_Config_to_str', res, self) + return ", ".join(res) @@ -440,15 +401,16 @@ def kernel(x_ptr, x_size, **META): """ def decorator(fn): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization if split_params or tiling_params: - from .autotiling_tuner import AutoTilingTuner - return AutoTilingTuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, - post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, - use_cuda_graph=use_cuda_graph, do_bench=do_bench, auto_profile_dir=auto_profile_dir, - split_params=split_params, tiling_params=tiling_params, low_dims=low_dims, - dual_reduction=dual_reduction, persistent_reduction=persistent_reduction) - else: - return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, + return flagtree_backend_specialization('new_AutoTilingTuner', fn, configs, key, reset_to_zero, restore_value, pre_hook, + post_hook, prune_configs_by, warmup, rep, + use_cuda_graph, do_bench, auto_profile_dir, + split_params, tiling_params, low_dims, + dual_reduction, persistent_reduction) + + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, use_cuda_graph=use_cuda_graph, do_bench=do_bench, auto_profile_dir=auto_profile_dir) diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index c3b97a764..58228de35 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -58,3 +58,25 @@ def reset_active(self): driver = DriverConfig() + + +# flagtree backend specialization +def flagtree_backend_specialization(function_name: str, *args, **kwargs): + if hasattr(driver.active, "flagtree_backend_specialization"): + flagtree_backend_specialization = driver.active.flagtree_backend_specialization + if hasattr(flagtree_backend_specialization, function_name): + func = getattr(flagtree_backend_specialization, function_name) + return func(*args, **kwargs) + raise RuntimeError(f"{function_name} not found in flagtree_backend_specialization") + raise RuntimeError(f"flagtree_backend_specialization not found in {driver.active}") + + +# flagtree backend func specialization +def flagtree_backend_func_specialization(function_name: str): + if hasattr(driver.active, "flagtree_backend_specialization"): + flagtree_backend_specialization = driver.active.flagtree_backend_specialization + if hasattr(flagtree_backend_specialization, function_name): + func = 
getattr(flagtree_backend_specialization, function_name) + return func + raise RuntimeError(f"{function_name} not found in flagtree_backend_specialization") + raise RuntimeError(f"flagtree_backend_specialization not found in {driver.active}") diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 08422611f..aa7d8b800 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -6,14 +6,11 @@ import os import re import textwrap -import tokenize from collections import defaultdict from functools import cached_property from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple from ..runtime.driver import driver -from ..backends.ascend.compiler import AscendAttrsDescriptor from types import ModuleType -from io import StringIO TRITON_MODULE = __name__[:-len(".runtime.jit")] @@ -331,6 +328,7 @@ def __getitem__(self, grid) -> T: memorizes the grid. """ return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) + # return cast(T, functools.partial(cast(Callable, self.run), grid=grid)) def serialize_specialization_data(name, signature, constants, attrs, options, key): @@ -568,7 +566,10 @@ def run(self, *args, grid, warmup, **kwargs): # parse options from ..compiler import make_backend device = driver.active.get_current_device() - if ('stream' not in kwargs.keys()): + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if flagtree_backend_specialization("is_set_stream_in_kwargs", kwargs): stream = driver.active.get_current_stream(device) target = driver.active.get_current_target() backend = make_backend(target) @@ -593,20 +594,17 @@ def run(self, *args, grid, warmup, **kwargs): # deprecated arguments assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used" assert "device" not in kwargs, "device option is deprecated; current device will be used" - # assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if flagtree_backend_specialization("is_stream_option_deprecated"): + assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + for k in excess_kwargs: if k not in options.__dict__: raise KeyError("Keyword argument %s was specified but unrecognised" % k) - ignor_params = ["debug", "sanitize_overflow", "llvm_version", "kernel_name", \ - "allowed_dot_input_precisions", "multibuffer", "stream"] - not_work_params = [] - for k in kwargs: - if k in ignor_params: - continue - elif k in excess_kwargs: - not_work_params.append(k) - if len(not_work_params) != 0: - print("[WARNING] Please DO NOT tune args {}!".format(not_work_params)) + + flagtree_backend_specialization("ignore_params_in_JITFunction_run", kwargs, excess_kwargs) bound_vals = tuple(bound_args.values()) @@ -660,16 +658,17 @@ def run(self, *args, grid, warmup, **kwargs): grid_0 = grid[0] grid_1 = grid[1] if grid_size > 1 else 1 grid_2 = grid[2] if grid_size > 2 else 1 - grid_all_size = grid_0 * grid_1 * grid_2 - if os.getenv("TRITON_ALL_BLOCKS_PARALLEL", "0") == "0": - if grid_all_size > 65535: - raise RuntimeError("grid should be less than 65536! 
You can try \"export TRITON_ALL_BLOCKS_PARALLEL=1\" to avoid this problem.") - if ('stream' in kwargs.keys()): - stream = kwargs["stream"] + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("check_grid_size", grid_0, grid_1, grid_2) + stream = flagtree_backend_specialization("set_stream_from_kwargs", kwargs, stream) + # launch kernel launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) - # explicitly define run method and load kernel binary - kernel._init_handles() + + flagtree_backend_specialization("explicit_load_kernel_library", kernel) + kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals) return kernel @@ -749,6 +748,7 @@ def warmup(self, *args, grid, **kwargs): def preload(self, specialization_data): from ..compiler import compile, ASTSource + from triton.backends.compiler import AttrsDescriptor import json import triton.language as tl device = driver.active.get_current_device() @@ -761,7 +761,13 @@ def preload(self, specialization_data): for key, value in deserialized_obj['constants'].items() } signature = dict(deserialized_obj['signature'].items()) - src = ASTSource(self, signature, constants, AscendAttrsDescriptor.from_dict(deserialized_obj['attrs'])) + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + src = ASTSource(self, signature, constants, + flagtree_backend_specialization('get_JITFunction_spec_attr', deserialized_obj) + if flagtree_backend_specialization('is_JITFunction_spec_attr') + else AttrsDescriptor.from_dict(deserialized_obj['attrs'])) options = { key: tuple(value) if isinstance(value, list) else value for key, value in deserialized_obj['options'].items() @@ -775,29 +781,18 @@ def preload(self, specialization_data): # the user might want to monkey-patch self.src dynamically. # Our unit tests do this, for example. 
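# A hedged sketch (illustrative only, not part of the patch) of what the backend-side
# 'maps_line_numbers_to_comment_hints' hook used by the parse() hunk below might look like,
# assuming it mirrors the tokenize-based "#@hint:" scanner that this diff removes from
# JITFunction.parse. jit_fn.src is the kernel's Python source, as in the removed code.
import tokenize
from io import StringIO

def maps_line_numbers_to_comment_hints(jit_fn):
    # Map source line numbers to the text of "#@hint: ..." comments.
    hints = {}
    for tok_type, tok_text, start, _end, _line in tokenize.generate_tokens(StringIO(jit_fn.src).readline):
        if tok_type == tokenize.COMMENT:
            comment = tok_text.replace(" ", "").strip()
            if comment.startswith('#@hint:'):
                hints[start[0]] = comment[len('#@hint:'):].strip()
    return hints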
def parse(self): - # Maps line numbers to comment hints - line_flagtree_hints = {} - code_str = self.src - g = tokenize.generate_tokens(StringIO(code_str).readline) - for tok_type, tok_text, start, end, _ in g: - if tok_type == tokenize.COMMENT: - comment = tok_text.replace(" ", "").strip() - if comment.startswith('#@hint:'): - flagtree_hints = comment[len('#@hint:'):].strip() - # Record the line number of the comment - line_num = start[0] - line_flagtree_hints[line_num] = flagtree_hints - - # print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}") + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + line_flagtree_hints = flagtree_backend_specialization('maps_line_numbers_to_comment_hints', self) tree = ast.parse(self.src) assert isinstance(tree, ast.Module) assert len(tree.body) == 1 assert isinstance(tree.body[0], ast.FunctionDef) - # Attach the line number to comment mapping to the function definition node - tree.body[0].line_flagtree_hints = line_flagtree_hints - + # flagtree backend specialization + flagtree_backend_specialization('attach_line_number_to_comment_mapping', tree, line_flagtree_hints) + return tree def __call__(self, *args, **kwargs): diff --git a/python/triton/testing.py b/python/triton/testing.py index b929ef22c..46f01e784 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -1,8 +1,6 @@ import functools import os import subprocess -import multiprocessing -import os import sys from contextlib import contextmanager from typing import Any, Dict, List @@ -114,10 +112,10 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m assert return_mode in ["min", "max", "mean", "median", "all"] import torch - enable_bench_npu = os.getenv("TRITON_BENCH_METHOD", 'default').lower() == 'npu' - if torch.npu.is_available() and enable_bench_npu: - avg_time = do_bench_npu(fn, warmup=max(5, warmup), active=max(30, rep)) - return _summarize_statistics(torch.tensor([avg_time], dtype=torch.float), quantiles, return_mode) + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if flagtree_backend_specialization('is_do_bench_npu'): + return flagtree_backend_specialization('ext_do_bench_npu') di = runtime.driver.active.get_device_interface() @@ -164,99 +162,6 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) return _summarize_statistics(times, quantiles, return_mode) -def collect_files(base_dir): - import pandas as pd - for root, dirs, files in os.walk(base_dir): - for file in files: - if file != 'op_statistic.csv': - continue - target_file = os.path.join(root, file) - df = pd.read_csv(target_file) - triton_rows = df[df['OP Type'].str.startswith('triton', na=False)] - if not triton_rows.empty: - return triton_rows['Avg Time(us)'].values[0] - return float('inf') - return float('inf') - - -def collect_single(base_dir: str, key: str = None) -> float: - if not os.path.exists(base_dir): - return float('inf') - - import pandas as pd - for root, _, files in os.walk(base_dir): - for file in files: - if file != 'op_statistic.csv': - continue - target_file = os.path.join(root, file) - df = pd.read_csv(target_file) - if key is not None: - key_rows = df[df['OP Type'].str.startswith(key, na=False)] - if not key_rows.empty: - return key_rows['Avg Time(us)'].values[0] - return float('inf') - else: - # default: read 
the first row except header - return df.loc[0, 'Avg Time(us)'] - - return float('inf') - - -def do_bench_npu(fn, warmup=5, active=30, prof_dir=None, keep_res=False): - import torch - import torch_npu - from datetime import datetime, timezone - - # warmup kernel - fn() - torch.npu.synchronize() - - experimental_config = torch_npu.profiler._ExperimentalConfig( - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, - l2_cache=False, - data_simplification=False - ) - skip_first = 1 - wait = 0 - repeat = 1 - total = skip_first + (wait + warmup + active) * repeat - - if prof_dir is not None: - torch_path = prof_dir - else: - process = multiprocessing.current_process() - pid = process.pid - process_name = process.name - timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S") - base_path = os.path.join(runtime.cache.get_home_dir(), ".triton", "profile_results") - torch_path = os.path.join(base_path, f"prof_{timestamp}_{process_name}-{pid}") - with torch_npu.profiler.profile( - activities=[ - torch_npu.profiler.ProfilerActivity.NPU - ], - schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, skip_first=skip_first), - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path), - record_shapes=False, - profile_memory=False, - with_stack=False, - with_flops=False, - with_modules=False, - experimental_config=experimental_config, - ) as prof: - for _ in range(total): - fn() - prof.step() - torch.npu.synchronize() - - time = collect_single(torch_path) - - if not keep_res: - import shutil - if os.path.exists(torch_path): - shutil.rmtree(torch_path) - - return time def assert_close(x, y, atol=None, rtol=None, err_msg=''): """ @@ -431,6 +336,7 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b ax.legend() ax.set_xlabel(bench.xlabel or first_x) ax.set_ylabel(bench.ylabel) + # ax.set_title(bench.plot_name) ax.set_xscale("log" if bench.x_log else "linear") ax.set_yscale("log" if bench.y_log else "linear") if show_plots: @@ -609,205 +515,56 @@ def get_max_simd_tflops(dtype, clock_rate, device=None): tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 return tflops + # Patch the triton language API here because triton's __init__.py # import testing in the last stages. from .language.tensor_descriptor import ( tensor_descriptor, tensor_descriptor_type, ) - -from .language.core_ext import ( - dot, - cast, - gather, - get_element, - insert_slice, - extract_slice, - trans, - __add__, - __radd__, - __sub__, - __rsub__, - __mul__, - __rmul__, - __lshift__, - __rshift__, - parallel, - compile_hint, - make_tensor_descriptor, - load_tensor_descriptor, - store_tensor_descriptor, - multibuffer, - sync_block_all, - sync_block_set, - sync_block_wait, - dtype_to_ir, - sort -) from .language.standard_ext import flip, sigmoid, softmax, isfinited, finitef, rint, atan2 -from .language.math_ext import ( - umulhi, - exp, - exp2, - log, - log2, - cos, - sin, - sqrt, - sqrt_rn, - rsqrt, - div_rn, - erf, - tanh, - floor, - ceil, - _check_dtype, - fma, -) -from .language.semantic_ext import ( - arange, - floordiv, - atom_red_typechecking_impl, - atomic_cas, - atomic_max, - atomic_min, - _load_legacy, - maximum, - minimum, - mod, - invert, - logical_and, - logical_or, - not_, - and_, - or_, - xor_, - minus, - dot_scaled, -) from . 
import language -language.cast = cast -language.dot = dot language.flip = flip language.sigmoid = sigmoid language.softmax = softmax -language.gather = gather -language.insert_slice = insert_slice -language.extract_slice = extract_slice -language.get_element = get_element -language.tensor.__add__ = __add__ -language.tensor.__radd__ = __radd__ -language.tensor.__sub__ = __sub__ -language.tensor.__rsub__ = __rsub__ -language.tensor.__mul__ = __mul__ -language.tensor.__rmul__ = __rmul__ -language.tensor.__lshift__ = __lshift__ -language.tensor.__rshift__ = __rshift__ -language.trans = trans -language.parallel = parallel -language.compile_hint = compile_hint -language.sort = sort -language.multibuffer = multibuffer -language.sync_block_all = sync_block_all -language.sync_block_set = sync_block_set -language.sync_block_wait = sync_block_wait -language.make_tensor_descriptor = make_tensor_descriptor language.tensor_descriptor = tensor_descriptor language.tensor_descriptor_type = tensor_descriptor_type -language.load_tensor_descriptor = load_tensor_descriptor -language.store_tensor_descriptor = store_tensor_descriptor - -language.semantic.arange = arange -language.semantic.floordiv = floordiv -language.semantic.atom_red_typechecking_impl = atom_red_typechecking_impl -language.semantic.atomic_cas = atomic_cas -language.semantic.atomic_max = atomic_max -language.semantic.atomic_min = atomic_min -language.semantic._load_legacy = _load_legacy -language.semantic.maximum = maximum -language.semantic.minimum = minimum -language.semantic.invert = invert -language.semantic.logical_and = logical_and -language.semantic.logical_or = logical_or -language.semantic.mod = mod -language.semantic.not_ = not_ -language.semantic.and_ = and_ -language.semantic.or_ = or_ -language.semantic.xor_ = xor_ -language.semantic.minus = minus -language.semantic.dot_scaled = dot_scaled - -language.umulhi = umulhi -language.exp = exp -language.exp2 = exp2 -language.log = log -language.log2 = log2 -language.cos = cos -language.sin = sin -language.sqrt = sqrt -language.sqrt_rn = sqrt_rn -language.rsqrt = rsqrt -language.div_rn = div_rn -language.erf = erf -language.tanh = tanh -language.floor = floor -language.ceil = ceil -language.core.dtype.to_ir = dtype_to_ir -language.fma = fma -language.math.umulhi = umulhi -language.math.exp = exp -language.math.exp2 = exp2 -language.math.log = log -language.math.log2 = log2 -language.math.cos = cos -language.math.sin = sin -language.math.sqrt = sqrt -language.math.sqrt_rn = sqrt_rn -language.math.rsqrt = rsqrt -language.math.div_rn = div_rn -language.math.erf = erf -language.math.tanh = tanh -language.math.floor = floor -language.math.ceil = ceil -language.math._check_dtype = _check_dtype -language.math.fma = fma -language.math.isnan = language.extra.ascend.libdevice.isnan -language.math.isinf = language.extra.ascend.libdevice.isinf -language.math.reciprocal = language.extra.ascend.libdevice.reciprocal -language.math.log1p = language.extra.ascend.libdevice.log1p -language.math.relu = language.extra.ascend.libdevice.relu -language.math.tan = language.extra.ascend.libdevice.tan -language.math.atan = language.extra.ascend.libdevice.atan + +language.umulhi = language.extra.ascend.libdevice.umulhi +language.exp = language.extra.ascend.libdevice.exp +language.exp2 = language.extra.ascend.libdevice.exp2 +language.log = language.extra.ascend.libdevice.log +language.log2 = language.extra.ascend.libdevice.log2 +language.cos = language.extra.ascend.libdevice.cos +language.sin = 
language.extra.ascend.libdevice.sin +language.sqrt = language.extra.ascend.libdevice.sqrt +language.sqrt_rn = language.extra.ascend.libdevice.sqrt_rn +language.rsqrt = language.extra.ascend.libdevice.rsqrt +language.div_rn = language.extra.ascend.libdevice.div_rn +language.erf = language.extra.ascend.libdevice.erf +language.tanh = language.extra.ascend.libdevice.tanh +language.floor = language.extra.ascend.libdevice.floor +language.ceil = language.extra.ascend.libdevice.ceil +language.fma = language.extra.ascend.libdevice.fma +language.math.umulhi = language.extra.ascend.libdevice.umulhi +language.math.exp = language.extra.ascend.libdevice.exp +language.math.exp2 = language.extra.ascend.libdevice.exp2 +language.math.log = language.extra.ascend.libdevice.log +language.math.log2 = language.extra.ascend.libdevice.log2 +language.math.cos = language.extra.ascend.libdevice.cos +language.math.sin = language.extra.ascend.libdevice.sin +language.math.sqrt = language.extra.ascend.libdevice.sqrt +language.math.sqrt_rn = language.extra.ascend.libdevice.sqrt_rn +language.math.rsqrt = language.extra.ascend.libdevice.rsqrt +language.math.div_rn = language.extra.ascend.libdevice.div_rn +language.math.erf = language.extra.ascend.libdevice.erf language.math.tanh = language.extra.ascend.libdevice.tanh -language.math.ilogb = language.extra.ascend.libdevice.ilogb -language.math.ldexp = language.extra.ascend.libdevice.ldexp -language.math.pow = language.extra.ascend.libdevice.pow -language.math.flip = language.extra.ascend.libdevice.flip -language.math.atan2 = language.extra.ascend.libdevice.atan2 -language.math.div_rz = language.extra.ascend.libdevice.div_rz -language.math.fmod = language.extra.ascend.libdevice.fmod -language.math.trunc = language.extra.ascend.libdevice.trunc -language.math.round = language.extra.ascend.libdevice.round +language.math.floor = language.extra.ascend.libdevice.floor +language.math.ceil = language.extra.ascend.libdevice.ceil +language.math._check_dtype = language.extra.ascend.libdevice._check_dtype +language.math.fma = language.extra.ascend.libdevice.fma language.math.finitef = finitef language.math.isfinited = isfinited language.math.rint = rint language.math.atan2 = atan2 -language.extra.ascend.libdevice.umulhi = language.math.umulhi -language.extra.ascend.libdevice.exp = language.math.exp -language.extra.ascend.libdevice.exp2 = language.math.exp2 -language.extra.ascend.libdevice.log = language.math.log -language.extra.ascend.libdevice.log2 = language.math.log2 -language.extra.ascend.libdevice.cos = language.math.cos -language.extra.ascend.libdevice.sin = language.math.sin -language.extra.ascend.libdevice.sqrt = language.math.sqrt -language.extra.ascend.libdevice.sqrt_rn = language.math.sqrt_rn -language.extra.ascend.libdevice.rsqrt = language.math.rsqrt -language.extra.ascend.libdevice.div_rn = language.math.div_rn -language.extra.ascend.libdevice.erf = language.math.erf -language.extra.ascend.libdevice.tanh = language.math.tanh -language.extra.ascend.libdevice.floor = language.math.floor -language.extra.ascend.libdevice.ceil = language.math.ceil -language.extra.ascend.libdevice.fdiv = language.math.fdiv -language.extra.ascend.libdevice.fma = language.math.fma -language.extra.ascend.libdevice.abs = language.math.abs diff --git a/third_party/ascend/backend/driver.py b/third_party/ascend/backend/driver.py index 8ea0873a8..603bb84cc 100644 --- a/third_party/ascend/backend/driver.py +++ b/third_party/ascend/backend/driver.py @@ -105,6 +105,9 @@ class NPUDriver(DriverBase): def 
__init__(self): self.utils = NPUUtils() self.launcher_cls = NPULauncher + # flagtree backend specialization + from triton.backends.ascend import flagtree_backend_specialization + self.flagtree_backend_specialization = flagtree_backend_specialization super().__init__() @classmethod diff --git a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py new file mode 100644 index 000000000..3c7ddc671 --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py @@ -0,0 +1,135 @@ +from .triton.compiler.compiler import * +from .triton.compiler.errors import * +from .triton.compiler.code_generator import * +from .triton.runtime.jit import * +from .triton.runtime.autotuner import * +from .triton.language._utils import * +from .triton.language.core import * +from .triton.language.semantic import * +from .triton.testing import * + +__all__ = [ + # compiler.compiler + 'ext_ASTSource_attrs', + 'opt_ascend_compile_speed', + 'set_CompiledKernel_metadata_stream', + 'handle_compile_error', + 'is_CompiledKernel_getattribute_need_init_handles', + # compiler.code_generator + 'anno_CodeGenerator_visit_Assign', + 'ext_CodeGenerator_visit_Assign_hint_anno', + 'init_bind_sub_block', + 'is_visit_For_support_parallel', + 'set_bind_sub_block_when_parallel', + 'check_override_bind_sub_block', + 'forop_setattr_for_bind_sub_block', + 'need_repr_in_CodeGenerator_CompilationError', + # runtime.jit + 'is_set_stream_in_kwargs', + 'is_stream_option_deprecated', + 'ignore_params_in_JITFunction_run', + 'set_stream_from_kwargs', + 'check_grid_size', + 'explicit_load_kernel_library', + 'is_JITFunction_spec_attr', + 'get_JITFunction_spec_attr', + 'maps_line_numbers_to_comment_hints', + 'attach_line_number_to_comment_mapping', + # runtime.autotuner + 'set_Autotuner_auto_profile_dir', + 'has_spec_default_Autotuner_configs', + 'get_spec_default_Autotuner_configs', + 'ext_Autotuner_do_bench_MLIRCompilationError', + 'ext_Autotuner_profile', + 'set_Config_BiShengIR_options', + 'ext_Config_all_kwargs', + 'ext_Config_to_str', + 'new_AutoTilingTuner', + # language._utils + 'is_block_shape_check_power_of_two', + 'get_primitive_bitwidth', + # language.core + "ext_cast_set_overflow_modes", + "ext_cast_check_overflow_mode", + "ext_trans_unwrap_iterable", + "check_dot_deprecated_param_allow_tf32", + "check_dot_invalid_input_precision", + "ext_core_gather", + "ext_core_insert_slice", + "ext_core_extract_slice", + "ext_core_get_element", + "ext_core_add", + "ext_core_radd", + "ext_core_sub", + "ext_core_rsub", + "ext_core_mul", + "ext_core_rmul", + "ext_core_lshift", + "ext_core_rshift", + "ext_core_compile_hint", + "ext_core_sort", + "ext_core_multibuffer", + "ext_core_sync_block_all", + "ext_core_sync_block_set", + "ext_core_sync_block_wait", + "ext_core_load_tensor_descriptor", + "ext_core_store_tensor_descriptor", + "ext_core_make_tensor_descriptor", + "ext_core_dtype_to_ir", + # language.semantic + "is_arange_check_power_of_two", + "check_arange_less_than_max_numel", + "is_cast_src_dst_scalar_type_equal", + "check_unsupported_fp8_fp64", + "ext_dot_lhs_supported_type", + "ext_dot_rhs_supported_type", + "dot_check_hf32_input_precision", + "is_dot_check_max_num_imprecise_acc", + "reset_dot_max_num_imprecise_acc", + "check_was_bool_to_int8_dtype", + "check_was_bool_to_int8_dtype_and_cast", + "check_unexpected_dtype_float", + "check_unexpected_dtype_bool", + "set_load_legacy_other_input", + 
"cast_back_when_load_legacy_ptr_is_bool", + "set_attr_was_bool_to_int8", + "is_atomic_need_original_check", + "ext_atomic_element_typechecking", + "is_atomic_cas_need_element_bitwidth_check", + "ext_atomic_cas_element_typechecking", + "is_atomic_max_no_bitcast", + "is_atomic_min_no_bitcast", + "atomic_max_returning_tensor", + "atomic_min_returning_tensor", + "is_float_format_support_bf16", + "is_float_format_support_fp16", + "ext_dot_scaled_validate_lhs_dtype", + "ext_dot_scaled_validate_rhs_dtype", + "ext_dot_scaled_check_same_dtype", + "is_dot_scaled_need_original_check", + "ext_dot_scaled_check_lhs_rhs_format", + "dot_scaled_recheck_rhs_scale_is_none", + "dot_scaled_check_lhs_scale_is_none", + "is_dot_scaled_support_rhs_scale", + "check_dot_scaled_lhs_scale_dtype", + "check_dot_scaled_rhs_scale_dtype", + "dot_scaled_lhs_bitcast_to_fp_type", + "dot_scaled_rhs_bitcast_to_fp_type", + "check_dot_scaled_dimension", + "check_dot_scaled_pack_size", + "set_dot_scaled_lhs_scale_handle", + "ext_semantic_gather", + "ext_semantic_insert_slice", + "ext_semantic_extract_slice", + "ext_semantic_get_element", + "ext_semantic_compile_hint", + "ext_semantic_custom_op", + "ext_semantic_sort", + "ext_semantic_scalar_constant", + "ext_semantic_make_scalar", + "ext_semantic_make_tensor_descriptor", + # testing + 'is_do_bench_npu', + 'ext_do_bench_npu', + 'patch_triton_language' +] diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/code_generator.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/code_generator.py new file mode 100644 index 000000000..b90a9f99e --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/code_generator.py @@ -0,0 +1,63 @@ +def anno_CodeGenerator_visit_Assign(): + # flagtree: First, do normal assignment processing + return + +def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values): + import ast + from triton.compiler.code_generator import _is_triton_value + # flagtree: After normal processing, check if we need to add hint annotation + if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'): + line_num = node.lineno + # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later + function_def = code_generator.jit_fn.parse() + line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) + flagtree_hints = line_flagtree_hints.get(line_num) + + # Check if this is a tl.load call with dot_pad_only_k hint + if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and + isinstance(node.value, ast.Call) and + isinstance(node.value.func, ast.Attribute) and + isinstance(node.value.func.value, ast.Name) and + node.value.func.value.id == 'tl' and + node.value.func.attr == 'load'): + + # Add hint annotation to the loaded tensor(s) + for name, value in zip(names, values): + if _is_triton_value(value): + # print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}") + # Create hint annotation + hint_val = code_generator.builder.get_unit_attr() + code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val) + +def init_bind_sub_block(): + return None + +def is_visit_For_support_parallel(): + return True + +def set_bind_sub_block_when_parallel(IteratorClass, iterator, bind_sub_block): + import triton.language as language + if (IteratorClass is language.parallel): + return iterator.bind_sub_block + return bind_sub_block + +def check_override_bind_sub_block(code_generator, node, 
bind_sub_block): + # flagtree: After normal processing, check if we need to override bind_sub_block + if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'): + line_num = node.lineno + # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later + function_def = code_generator.jit_fn.parse() + line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) + flagtree_hints = line_flagtree_hints.get(line_num) + + # Check if this is a range/for loop with bind_sub_block hint + if flagtree_hints and 'bind_sub_block' in flagtree_hints: + return True + # print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}") + return bind_sub_block + +def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block): + for_op.set_attr("bind_sub_block", code_generator.builder.get_bool_attr(bind_sub_block)) + +def need_repr_in_CodeGenerator_CompilationError(): + return True diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py new file mode 100644 index 000000000..ef1426020 --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py @@ -0,0 +1,32 @@ +def ext_ASTSource_attrs(ast_source): + from triton.backends.ascend.compiler import AscendAttrsDescriptor + if ast_source.attrs is None: + ast_source.attrs = AscendAttrsDescriptor() + +def opt_ascend_compile_speed(file_name, metadata_path, fn_cache_manager): + import os + compile_speed_opt = os.getenv("TRITON_ASCEND_COMPILE_SPEED_OPT", 'false').lower() in ('true', '1') + if (compile_speed_opt): + ttir_path = f"{file_name}.ttir" + if (metadata_path is None) and (fn_cache_manager.has_file(ttir_path)): + # Already compile once but failed. 
So directly return + raise Exception("already failed once") + +def set_CompiledKernel_metadata_stream(compiled_kernel, stream): + if stream is None: + return stream + return compiled_kernel.metadata.stream + +def handle_compile_error(e, ext): + from .errors import MLIRCompilationError + if (ext == "ttadapter"): + stage_name = "ConvertTritonIRToLinalgIR" + elif (ext == "npubin"): + stage_name = "ConvertLinalgRToBinary" + else: + stage_name = "MLIRCompile" + error_detail = e.stderr.decode('utf-8') if hasattr(e, 'stderr') and e.stderr else str(e) + raise MLIRCompilationError(stage_name, error_detail) + +def is_CompiledKernel_getattribute_need_init_handles(): + return False diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/errors.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/errors.py new file mode 100644 index 000000000..b1ef43a3b --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/errors.py @@ -0,0 +1,20 @@ +import importlib.util +import sys +from typing import Optional +from triton.compiler.errors import TritonError + +class MLIRCompilationError(TritonError): + def __init__(self, stage_name: Optional[str], message: Optional[str] = None): + self.stage_name = stage_name + self.message = f"\n" \ + f"{self.format_line_delim('[ERROR][Triton][BEG]')}" \ + f"[{self.stage_name}] encounters error:\n" \ + f"{self.filter_message(message)}" \ + f"{self.format_line_delim('[ERROR][Triton][END]')}" + def __str__(self): + return self.message + def filter_message(self, message): + # Content starting from "Stack dump without symbol names" means nothing to the users + return message.split("Stack dump without symbol names")[0] + def format_line_delim(self, keyword): + return f"///------------------{keyword}------------------\n" diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/language/_utils.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/_utils.py new file mode 100644 index 000000000..4810b0f6c --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/_utils.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Union, Dict +if TYPE_CHECKING: + from triton.language import core + IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type] + ObjPath = tuple[int, ...] 
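# A minimal usage sketch (an assumption, not part of the patch) of how the
# func-specialization indirection added in runtime/driver.py is expected to resolve:
# tensor_descriptor.dtype.__init__ looks up the backend's get_primitive_bitwidth
# (defined just below) and derives the element size from it. This requires an active
# driver that exposes flagtree_backend_specialization, i.e. the Ascend backend here.
from triton.runtime.driver import flagtree_backend_func_specialization

def _example_itemsize(type_name: str) -> int:
    # Hypothetical helper for illustration: "bf16" -> 16 bits -> 2 bytes.
    get_primitive_bitwidth = flagtree_backend_func_specialization("get_primitive_bitwidth")
    return get_primitive_bitwidth(type_name) // 8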
+ + +def is_block_shape_check_power_of_two(): + return False + + +BITWIDTH_DICT: Dict[str, int] = { + **{f"u{n}": n + for n in (1, 8, 16, 32, 64)}, + **{f"i{n}": n + for n in (1, 8, 16, 32, 64)}, + **{f"fp{n}": n + for n in (16, 32, 64)}, + **{f"fp8{suffix}": 8 + for suffix in ("e4nv", "e4b15", "e4b8", "e5", "e5b16")}, + "bf16": 16, + "void": 0, +} + + +def get_primitive_bitwidth(dtype: str) -> int: + return BITWIDTH_DICT[dtype] + + diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/language/core.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/core.py new file mode 100644 index 000000000..3bedb6383 --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/core.py @@ -0,0 +1,334 @@ +from typing import List, Sequence, Union +from triton._C.libtriton import ir +import triton.language.semantic as semantic +from triton.language.core import ( + _unwrap_iterable, + _constexpr_to_value, + constexpr, + tensor, + check_bit_width, + _unwrap_if_constexpr, + add, + sub, + mul, +) + +from triton.language.tensor_descriptor import tensor_descriptor, tensor_descriptor_base + +def ext_cast_set_overflow_modes(): + return ["trunc", "saturate"] + +def ext_cast_check_overflow_mode(overflow_mode, overflow_modes, ret, _builder): + if overflow_mode is not None: + if overflow_mode in overflow_modes: + semantic.compile_hint(ret, "overflow_mode", overflow_mode, _builder) + else: + raise ValueError(f"Unknown overflow_mode:{overflow_mode} is found.") + +def ext_trans_unwrap_iterable(dims): + return _unwrap_iterable(dims) + +def check_dot_deprecated_param_allow_tf32(allow_tf32): + assert ( + not allow_tf32 + ), "allow_tf32 is deprecated, please use input_precision='hf32' on Ascend instead." + +def check_dot_invalid_input_precision(input_precision): + assert input_precision not in [ + "tf32", + "tf32x3", + ], "input_precision == tf32 or tf32x3 is invalid, please use input_precision='hf32' on Ascend instead." + +def ext_core_gather(src, index, axis, _builder=None): + """Gather from a tensor along a given dimension. + :param src: the source tensor + :type src: Tensor + :param index: the index tensor + :type index: Tensor + :param axis: the dimension to gather along + :type axis: int + """ + axis = _constexpr_to_value(axis) + return semantic.gather(src, index, axis, _builder) + +def ext_core_insert_slice(ful, sub, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + """ + Insert a tensor to another tensor as specified by the operation’s offsets, sizes and strides arguments. + + :param ful: The tensor to receive tensor. + :type ful: Tensor + :param sub: The tensor to be inserted. + :type sub: Tensor + :param offsets: + :type offsets: tuple of ints + :param sizes: + :type sizes: tuple of ints + :param strides: + :type strides: tuple of ints + """ + assert len(ful.shape) > 0 + assert len(ful.shape) == len(sub.shape) + new_offsets = [ + semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o + for o in offsets + ] + out = semantic.insert_slice(ful, sub, new_offsets, sizes, strides, _builder) + return out + +def ext_core_extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + """ + Extract a tensor from another tensor as specified by the operation’s offsets, sizes and strides arguments. + + :param ful: The tensor to split. 
+    :type ful: Tensor
+    :param offsets:
+    :type offsets: tuple of ints
+    :param sizes:
+    :type sizes: tuple of ints
+    :param strides:
+    :type strides: tuple of ints
+    """
+    assert len(ful.shape) > 0
+    new_offsets = [
+        semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o
+        for o in offsets
+    ]
+    sub = semantic.extract_slice(ful, new_offsets, sizes, strides, _builder)
+    return sub
+
+def ext_core_get_element(src, indice, _builder=None, _generator=None):
+    """
+    get_element op reads a ranked tensor and returns one element as specified by the given indices.
+    The result of the op is a value with the same type as the elements of the tensor.
+    The arity of indices must match the rank of the accessed value.
+
+    :param src: The tensor to be accessed.
+    :type src: Tensor
+    :param indice:
+    :type indice: tuple of ints
+    """
+    assert len(src.shape) > 0
+    new_indice = [
+        semantic.to_tensor(i, _builder) if isinstance(i, constexpr) else i
+        for i in indice
+    ]
+    return semantic.get_element(src, new_indice, _builder)
+
+def ext_core_add(self, other, _builder=None):
+    return add(self, other, sanitize_overflow=False, _builder=_builder)
+
+def ext_core_radd(self, other, _builder=None):
+    return add(other, self, sanitize_overflow=False, _builder=_builder)
+
+def ext_core_sub(self, other, _builder=None):
+    return sub(self, other, sanitize_overflow=False, _builder=_builder)
+
+def ext_core_rsub(self, other, _builder=None):
+    return sub(other, self, sanitize_overflow=False, _builder=_builder)
+
+def ext_core_mul(self, other, _builder=None):
+    return mul(self, other, sanitize_overflow=False, _builder=_builder)
+
+def ext_core_rmul(self, other, _builder=None):
+    return mul(other, self, sanitize_overflow=False, _builder=_builder)
+
+def ext_core_lshift(self, other, _builder=None):
+    if self.type.scalar.is_floating():
+        raise TypeError(f"unexpected type {self.type.scalar}")
+    check_bit_width(self, other)
+    other = _unwrap_if_constexpr(other)
+    return semantic.shl(self, other, _builder)
+
+def ext_core_rshift(self, other, _builder=None):
+    if self.type.scalar.is_floating():
+        raise TypeError(f"unexpected type {self.type.scalar}")
+    other = _unwrap_if_constexpr(other)
+    check_bit_width(self, other)
+    if self.dtype.is_int_signed():
+        return semantic.ashr(self, other, _builder)
+    else:
+        return semantic.lshr(self, other, _builder)
+
+def ext_core_compile_hint(ptr, hint_name, hint_val=None, _builder=None):
+    def _unwrap(val):
+        return _unwrap_if_constexpr(val) if val else val
+
+    hint_name = _constexpr_to_value(hint_name)
+    assert isinstance(hint_name, str), f"hint name: {hint_name} is not a string"
+    if isinstance(hint_val, list):
+        hint_val = [_unwrap(val) for val in hint_val]
+    else:
+        hint_val = _unwrap(hint_val)
+    hint_val = _unwrap_if_constexpr(hint_val) if hint_val else hint_val
+    semantic.compile_hint(ptr, hint_name, hint_val, _builder)
+
+def ext_core_sort(ptr, dim=-1, descending=False, _builder=None):
+    """
+    Triton sort front-end interface.
+
+    Args:
+        ptr: tl.tensor, the input tensor
+        dim: int or tl.constexpr[int], the dimension to sort along
+        descending: bool or tl.constexpr[bool], whether to sort in descending order
+        _builder: ir.builder, the underlying IR builder
+    Returns:
+        values: tl.tensor, the sorted values (same dtype as the input)
+    """
+
+    try:
+        dim = int(dim.value) if hasattr(dim, "value") else int(dim)
+    except Exception as e:
+        raise TypeError(f"dim must be an integer (or tl.constexpr int), got {dim!r}. Error: {str(e)}") from e
+
+    if hasattr(descending, "value"):
+        descending = bool(descending.value)
+    else:
+        descending = bool(descending)
+
+    ret = semantic.sort(ptr, dim, descending, _builder)
+    base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type
+    if base_ty.is_int8() or base_ty.is_int16():
+        semantic.compile_hint(ret, "overflow_mode", constexpr("saturate"), _builder)
+    return ret
+
+def ext_core_multibuffer(src: tensor, size, _builder=None):
+    """
+    Set multi_buffer for an existing tensor.
+    :src: the tensor to be bufferized multiple times
+    :size: number of copies
+    """
+    buffer_size = _constexpr_to_value(size)
+    assert isinstance(buffer_size, int) and buffer_size == 2, f"only a buffer size of 2 is supported, got {buffer_size}"
+    semantic.compile_hint(src, "multi_buffer", buffer_size, _builder)
+
+def ext_core_sync_block_all(mode, event_id, _builder=None):
+    mode = _constexpr_to_value(mode)
+    event_id = _constexpr_to_value(event_id)
+    assert isinstance(mode, str), f"mode: {mode} is not a string"
+    assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15"
+    assert mode == "all_cube" or mode == "all_vector" or mode == "all", f"ERROR: mode = {mode}, only supports all_cube/all_vector/all"
+    semantic.custom_op(_builder, "sync_block_all", mode=mode, event_id=event_id)
+
+def ext_core_sync_block_set(sender, receiver, event_id, _builder=None):
+    sender = _constexpr_to_value(sender)
+    receiver = _constexpr_to_value(receiver)
+    event_id = _constexpr_to_value(event_id)
+    assert isinstance(sender, str) and (sender == "cube" or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector"
+    assert isinstance(receiver, str) and (receiver == "cube" or receiver == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector"
+    assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15"
+    if sender == receiver:
+        raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube')
+    semantic.custom_op(_builder, "sync_block_set", sender=sender, event_id=event_id)
+
+def ext_core_sync_block_wait(sender, receiver, event_id, _builder=None):
+    sender = _constexpr_to_value(sender)
+    receiver = _constexpr_to_value(receiver)
+    event_id = _constexpr_to_value(event_id)
+    assert isinstance(sender, str) and (sender == "cube" or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector"
+    assert isinstance(receiver, str) and (receiver == "cube" or receiver == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector"
+    assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15"
+    if sender == receiver:
+        raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube')
+    semantic.custom_op(_builder, "sync_block_wait", sender=sender, event_id=event_id)
+
+def ext_core_load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union[constexpr, tensor]],
+                                    _builder=None) -> tensor:
+    """Load a block of data from a tensor descriptor."""
+    return desc.load(offsets, _builder=_builder)
+
+def ext_core_store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union[constexpr, tensor]], value: tensor,
+                                     _builder=None) -> tensor:
+    """Store a block of data to a tensor descriptor."""
+    return desc.store(offsets, value, _builder=_builder)
+
+def ext_core_make_tensor_descriptor(
+    base: tensor,
+    shape: List[tensor],
+    
strides: List[tensor], + block_shape: List[constexpr], + _builder=None, +) -> tensor_descriptor: + """Make a tensor descriptor object + + :param base: the base pointer of the tensor, must be 16-byte aligned + :param shape: A list of non-negative integers representing the tensor shape + :param strides: A list of tensor strides. Leading dimensions must be multiples + of 16-byte strides and the last dimension must be contiguous. + :param block_shape: The shape of block to be loaded/stored from global memory + + Notes + ***** + On NVIDIA GPUs with TMA support, this will result in a TMA descriptor object + and loads and stores from the descriptor will be backed by the TMA hardware. + + Currently only 2-5 dimensional tensors are supported. + + Example + ******* + .. code-block:: python + + @triton.jit + def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + desc = tl.make_tensor_descriptor( + in_out_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + + value = desc.load([moffset, noffset]) + desc.store([moffset, noffset], tl.abs(value)) + + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + M, N = 256, 256 + x = torch.randn(M, N, device="cuda") + M_BLOCK, N_BLOCK = 32, 32 + grid = (M // M_BLOCK, N // N_BLOCK) + inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK) + + """ + return semantic.make_tensor_descriptor(base, shape, strides, block_shape, _builder) + +def ext_core_dtype_to_ir(self, builder: ir.builder) -> ir.type: + if self.name.startswith("fp8"): + raise ValueError(f'unexpected type fp8.') + + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name in ('int8', 'uint8'): + return builder.get_int8_ty() + elif self.name in ('int16', 'uint16'): + return builder.get_int16_ty() + elif self.name in ('int32', 'uint32'): + return builder.get_int32_ty() + elif self.name in ('int64', 'uint64'): + return builder.get_int64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e5b16': + return builder.get_fp8e5b16_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b8': + return builder.get_fp8e4b8_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/language/semantic.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/semantic.py new file mode 100644 index 000000000..69821759d --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/semantic.py @@ -0,0 +1,390 @@ +from typing import List +import triton.language as tl +from triton._C.libtriton import ir +from triton.language import cast +from triton.language.semantic import to_tensor, bitcast, wrap_tensor +from triton.language._utils import TRITON_MAX_TENSOR_NUMEL +from triton.language.tensor_descriptor import ( + _unwrap_if_constexpr, + _unwrap_shape, + 
block_type, + tensor_descriptor +) + +def is_arange_check_power_of_two(): + return False + +def check_arange_less_than_max_numel(range): + if range > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"end - start must be less than or equal to TRITON_MAX_TENSOR_NUMEL = {TRITON_MAX_TENSOR_NUMEL}") + +def is_cast_src_dst_scalar_type_equal(src_sca_ty, dst_sca_ty): + if src_sca_ty == dst_sca_ty: + return True + return False + +def check_unsupported_fp8_fp64(src_sca_ty, dst_sca_ty): + if (src_sca_ty.is_fp8() or dst_sca_ty.is_fp8()) or (src_sca_ty.is_fp64() or dst_sca_ty.is_fp64()): + raise ValueError("[fp8, fp64] is unsupported on Ascend for now." + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + +def ext_dot_lhs_supported_type(): + return (tl.int1,) + +def ext_dot_rhs_supported_type(): + return (tl.int1,) + +def dot_check_hf32_input_precision(input_precision, ir, lhs, rhs, ret_scalar_ty): + if (input_precision == getattr(ir.INPUT_PRECISION, "HF32")): + if (not lhs.dtype.is_fp32() or not rhs.dtype.is_fp32() or not ret_scalar_ty.is_fp32()): + raise ValueError("input_precision = 'hf32' must be used with f32 * f32 = f32 on Ascend") + +def is_dot_check_max_num_imprecise_acc(): + return False + +def reset_dot_max_num_imprecise_acc(max_num_imprecise_acc): + max_num_imprecise_acc = 0 + return max_num_imprecise_acc + +def check_was_bool_to_int8_dtype(input): + if hasattr(input, 'was_bool_to_int8'): + if input.type.scalar.is_int8(): + raise TypeError(f"unexpected type bool") + +def check_was_bool_to_int8_dtype_and_cast(input, builder): + assert input.type.scalar.is_int8(), "input wat bool to int8. However, input.type is not int8." + return cast(input, tl.int1, builder) + +def check_unexpected_dtype_float(input): + if input.type.scalar.is_floating(): + raise TypeError(f"unexpected type {input.type.scalar}") + +def check_unexpected_dtype_bool(dtype): + if dtype.is_bool(): + raise TypeError(f"Unexpected dtype {dtype}") + +def set_load_legacy_other_input(builder): + return to_tensor(0, builder) + +def cast_back_when_load_legacy_ptr_is_bool(): + return False + +def set_attr_was_bool_to_int8(ret, is_bool): + if is_bool: + ret.was_bool_to_int8 = True + +def is_atomic_need_original_check(): + return False + +def ext_atomic_element_typechecking(element_ty, op): + # Add `tl.int64` restriction for NPU + if element_ty in [tl.int1, tl.int64, tl.float16, tl.float32, tl.float64, tl.bfloat16] and op in ['or', 'xor']: + raise ValueError(f"atomic_{op} does not support {str(element_ty)}. " + "All support dtypes are int8, int16, int32.") + if element_ty in [tl.int1, tl.int64, tl.float64, tl.bfloat16] and op == 'xchg': + raise ValueError(f"atomic_{op} does not support {str(element_ty)}. " + "All support dtypes are int8, int16, int32, float16, float32.") + if element_ty in [tl.int1, tl.int64, tl.float64]: + raise ValueError(f"atomic_{op} does not support {str(element_ty)}. " + "All support dtypes are int8, int16, int32, float16, float32, bfloat16.") + +def is_atomic_cas_need_element_bitwidth_check(): + return False + +def ext_atomic_cas_element_typechecking(element_ty): + if element_ty in [tl.int1, tl.int8, tl.float64, tl.bfloat16]: + raise ValueError(f"atomic_cas does not support {str(element_ty)}. 
" + "All support dtypes are int16, int32, int64, float16, float32.") + +def is_atomic_max_no_bitcast(): + return True + +def is_atomic_min_no_bitcast(): + return True + +def atomic_max_returning_tensor(ir, ptr, val, mask, sem, scope, builder): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + +def atomic_min_returning_tensor(ir, ptr, val, mask, sem, scope, builder): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + +def is_float_format_support_bf16(): + return True + +def is_float_format_support_fp16(): + return True + +def ext_dot_scaled_validate_lhs_dtype(lhs): + assert lhs.dtype == tl.bfloat16 or lhs.dtype == tl.float16, f"lhs matrix dtype must be bf16 or fp16" + +def ext_dot_scaled_validate_rhs_dtype(rhs): + assert rhs.dtype == tl.bfloat16 or rhs.dtype == tl.float16, f"rhs matrix dtype must be bf16 or fp16" + +def ext_dot_scaled_check_same_dtype(lhs, rhs): + assert lhs.dtype == rhs.dtype, f"lhs rhs matrix must get same dtype" + +def is_dot_scaled_need_original_check(): + return False + +def ext_dot_scaled_check_lhs_rhs_format(lhs_format, rhs_format): + lhs_format: str = lhs_format.value + rhs_format: str = rhs_format.value + allowed_formats = {"bf16", "fp16"} # unsupported fp8/4 dtype: "e2m1", "e4m3", "e5m2" + assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}" + assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}" + +def dot_scaled_recheck_rhs_scale_is_none(rhs_scale, rhs_scale_is_none): + rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None) + return rhs_scale_is_none + +def dot_scaled_check_lhs_scale_is_none(lhs_scale): + lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None) + return lhs_scale_is_none + +def is_dot_scaled_support_rhs_scale(): + return True + +def check_dot_scaled_lhs_scale_dtype(lhs_scale): + assert isinstance(lhs_scale, tl.tensor) and lhs_scale.dtype == tl.int8, f"lhs_scale must be int8 tensor" + +def check_dot_scaled_rhs_scale_dtype(rhs_scale, rhs_scale_is_none): + if not rhs_scale_is_none: + assert isinstance(rhs_scale, tl.tensor) and rhs_scale.dtype == tl.int8, f"rhs_scale must be int8 tensor" + +def _bitcast_to_fp_type(val, float_format, builder): + triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16, "fp16": tl.float16}.get(float_format) + if triton_ty is None: + assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}" + assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}" + return val + if val.dtype == triton_ty: + return val + else: + unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16, "fp16": tl.uint16}[float_format] + assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. 
Got {val.dtype}" + return bitcast(val, triton_ty, builder) + +def dot_scaled_lhs_bitcast_to_fp_type(lhs, lhs_format, builder): + lhs_format: str = lhs_format.value + lhs = _bitcast_to_fp_type(lhs, lhs_format, builder) + return lhs + +def dot_scaled_rhs_bitcast_to_fp_type(rhs, rhs_format, builder): + rhs_format: str = rhs_format.value + rhs = _bitcast_to_fp_type(rhs, rhs_format, builder) + return rhs + +def check_dot_scaled_dimension(lhs, rhs): + assert lhs.type.shape[-1] == rhs.type.shape[-2], ( + f"lhs last dimension (columns) {lhs.shape[-1]} " + f"must equal rhs penultimate dimension (rows) {rhs.shape[-2]}" + ) + +def check_dot_scaled_pack_size(PACKED_A, K, lhs_format, lhs, rhs): + lhs_format: str = lhs_format.value + PACKED_B = 2 if lhs_format == "e2m1" else 1 + assert K * PACKED_B == PACKED_A * lhs.type.shape[ + -1], f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + +def set_dot_scaled_lhs_scale_handle(lhs_scale, lhs_scale_is_none): + return None if lhs_scale_is_none else lhs_scale.handle + +def ext_semantic_gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + assert index.dtype.is_int(), "index must be an integer tensor" + if not src.dtype.is_floating(): + raise ValueError(f"Expected dtype fp16/fp32/bf16, but got {src.dtype}") + + rank = len(src.type.shape) + assert len(index.type.shape) == rank, "source and index tensors must have the same rank" + + assert -rank <= axis < rank, f"gather axis {axis} must be < source rank ({rank})" + if axis < 0: + axis += rank + + for d in range(rank): + if d == axis: + continue + assert index.type.shape[d] == src.type.shape[d], f"index dim {axis} must match the corresponding source dim" + + gather = builder.create_gather(src.handle, index.handle, axis) + return wrap_tensor(gather, src.type.scalar, index.type.shape) + +def ext_semantic_insert_slice(ful: tl.tensor, sub: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tl.tensor: + assert(len(ful.shape) == len(offsets)) + assert(len(ful.shape) == len(sizes)) + assert(len(ful.shape) == len(strides)) + assert(all([s>=1 for s in sizes])) + assert(all([s>=0 for s in strides])) + new_offsets = [o.handle for o in offsets] + ret_type = tl.block_type(ful.type.scalar, ful.shape) + out = builder.create_insert_slice(ful.handle, sub.handle, new_offsets, sizes, strides) + return tl.tensor(out, ret_type) + +def ext_semantic_extract_slice(ful: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tl.tensor: + assert(len(ful.shape) == len(offsets)) + assert(len(ful.shape) == len(sizes)) + assert(len(ful.shape) == len(strides)) + assert(all([s>=1 for s in sizes])) + assert(all([s>=0 for s in strides])) + new_offsets = [o.handle for o in offsets] + ret_type = tl.block_type(ful.type.scalar, sizes) + out = builder.create_extract_slice(ful.handle, new_offsets, sizes, strides) + return tl.tensor(out, ret_type) + +def ext_semantic_get_element(src: tl.tensor, indice: List[tl.tensor], builder: ir.builder): + if len(src.shape) != len(indice): + raise ValueError("Indice's rank must be equal to src tensor's rank") + + new_indice = [i.handle for i in indice] + result = builder.create_extract_scalar(src.handle, new_indice) + return wrap_tensor(result, src.type.scalar, None) + +def ext_semantic_compile_hint(ptr: tl.tensor, hint_name: str, hint_val, builder: ir.builder): + if not hint_val: + hint_val = builder.get_unit_attr() + elif 
isinstance(hint_val, bool):
+        hint_val = builder.get_bool_attr(hint_val)
+    elif isinstance(hint_val, int):
+        hint_val = builder.get_int32_attr(hint_val)
+    elif isinstance(hint_val, tl.constexpr):
+        hint_val = builder.get_str_attr(hint_val.value)
+    elif isinstance(hint_val, list):
+        # only support i64 array attr for now
+        hint_val = builder.get_i64_array_attr(hint_val)
+    else:
+        raise ValueError(f"Unsupported hint value type: {type(hint_val)}")
+    builder.create_annotation(ptr.handle, hint_name, hint_val)
+
+def ext_semantic_custom_op(builder: ir.builder, op_name: str, **kwargs):
+    if op_name == "sync_block_all":
+        return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["mode"], kwargs["event_id"])
+
+    elif op_name == "sync_block_set":
+        return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["sender"], kwargs["event_id"])
+
+    elif op_name == "sync_block_wait":
+        return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["sender"], kwargs["event_id"])
+
+    raise ValueError(f"Unsupported custom op: {op_name}")
+
+def ext_semantic_sort(ptr: tl.tensor, dim: int, descending, builder: ir.builder):
+    """
+    Triton sort operation.
+
+    Args:
+        ptr: tl.tensor, the input tensor
+        dim: int, the dimension to sort along; must be the trailing (last) dimension
+        descending: bool or constexpr, whether to sort in descending order
+        builder: ir.builder, the underlying IR builder
+    Returns:
+        values: tl.tensor, the sorted values (same type as the input)
+    """
+
+    allowed_types = {tl.int8, tl.int16, tl.bfloat16, tl.float16, tl.float32}
+    base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type
+    if base_ty not in allowed_types:
+        raise TypeError(
+            f"tt.sort only supports int8, int16, bfloat16, float16, float32, "
+            f"but got {ptr.type}"
+        )
+
+    shape = getattr(ptr, "shape", None)
+    if shape is None or shape == ():
+        shape = getattr(getattr(ptr, "type", None), "shape", None)
+
+    rank = None
+    if shape is not None:
+        try:
+            rank = len(shape)
+        except Exception:
+            rank = len(list(shape))
+
+    if rank is not None:
+        if rank < 1:
+            raise ValueError("tt.sort requires tensor rank >= 1")
+        last_dim = rank - 1
+        norm_dim = dim if dim >= 0 else dim + rank
+        if norm_dim != last_dim:
+            raise ValueError(
+                f"tt.sort only supports sorting along the last dimension "
+                f"(dim={last_dim} or -1) for shape {tuple(shape)}, but got dim={dim}"
+            )
+        dim = last_dim
+    else:
+        if dim != -1:
+            raise ValueError(
+                "tt.sort only supports the last dimension; when rank is unknown "
+                "you must pass dim=-1"
+            )
+
+    if hasattr(descending, "value"):
+        descending = bool(descending.value)
+    else:
+        descending = bool(descending)
+
+    sorted_vals = builder.create_sort(ptr.handle, dim, descending)
+
+    values = tl.tensor(sorted_vals, type=ptr.type)
+
+    return values
+
+def ext_semantic_scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
+    if dtype is None:
+        raise ValueError("dtype must be specified when value is not a tensor")
+    if value == 0:
+        value = builder.get_null_value(dtype.to_ir(builder))
+    else:
+        get_value_fn = getattr(builder, f"get_{dtype.name}")
+        value = get_value_fn(value)
+    return tl.tensor(value, dtype)
+
+def ext_semantic_make_scalar(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
+    if isinstance(value, tl.tensor):
+        assert value.numel.value == 1, "only accepts size-1 tensor"
+        return cast(value, dtype, builder)
+    return ext_semantic_scalar_constant(value, dtype, builder)
+
+def ext_semantic_make_tensor_descriptor(
+    base: tl.tensor,
+    shape: List[tl.tensor],
+    strides: List[tl.tensor],
+    block_shape: List[tl.constexpr],
+    builder: ir.builder
+) -> tensor_descriptor:
+    ndim = len(shape)
+    if not (1 <= ndim <= 5):
+        raise ValueError(f"Expected 1 <= ndim <= 5 but got {ndim} dimensions")
+    if len(strides) != ndim:
+        raise ValueError(f"Expected {ndim} strides but got {len(strides)}")
+    if len(block_shape) != ndim:
+        raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(block_shape)}")
+    assert isinstance(base.dtype, tl.pointer_type)
+    primitive_bitwidth = base.dtype.element_ty.primitive_bitwidth
+    if primitive_bitwidth == 1:
+        raise ValueError("int1 type is not supported for make_tensor_descriptor yet")
+    elem_size = primitive_bitwidth // 8
+    contig_dim_size = _unwrap_if_constexpr(block_shape[-1])
+    if contig_dim_size * elem_size < 16:
+        raise ValueError(
+            f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes"
+        )
+
+    strides[-1] = _unwrap_if_constexpr(strides[-1])
+    if strides[-1] != 1:
+        raise ValueError(f"Tensor descriptor last-dimension stride must be 1 but got {strides[-1]}")
+
+    shape = [ext_semantic_make_scalar(x, tl.int32, builder) for x in shape]
+    strides = [ext_semantic_make_scalar(x, tl.int64, builder) for x in strides]
+
+    block_shape = _unwrap_shape(block_shape)
+
+    assert isinstance(base.type, tl.pointer_type)
+    desc_block_type = block_type(base.type.element_ty, block_shape)
+    base_handle = base.handle
+    is_signed_int = base.type.element_ty.is_int_signed()
+
+    handle = builder.create_make_tensor_descriptor(base_handle, [s.handle for s in shape],
+                                                   [s.handle for s in strides], block_shape, is_signed_int)
+    return tensor_descriptor(handle, shape, strides, desc_block_type)
diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/autotuner.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/autotuner.py
new file mode 100644
index 000000000..61e0badb0
--- /dev/null
+++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/autotuner.py
@@ -0,0 +1,99 @@
+def set_Autotuner_auto_profile_dir(autotuner, auto_profile_dir):
+    autotuner.auto_profile_dir = auto_profile_dir
+
+def has_spec_default_Autotuner_configs():
+    return True
+
+def get_spec_default_Autotuner_configs():
+    from triton.runtime.autotuner import Config
+    return Config({})
+
+def ext_Autotuner_do_bench_MLIRCompilationError(exception_types):
+    from ..compiler.errors import MLIRCompilationError
+    return (MLIRCompilationError)
+
+def _profile(autotuner, *args, config, **meta):
+    from triton.testing import do_bench_npu
+
+    # check for conflicts, i.e. meta-parameters both provided
+    # as kwargs and by the autotuner
+    conflicts = meta.keys() & config.kwargs.keys()
+    if conflicts:
+        raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
+                         " Make sure that you don't re-define auto-tuned symbols.")
+    # augment meta-parameters with tunable ones
+    current = dict(meta, **config.all_kwargs())
+    full_nargs = {**autotuner.nargs, **current}
+
+    def kernel_call():
+        if config.pre_hook:
+            config.pre_hook(full_nargs)
+        autotuner.pre_hook(full_nargs)
+        try:
+            autotuner.fn.run(
+                *args,
+                **current,
+            )
+        except Exception as e:
+            try:
+                autotuner.post_hook(full_nargs, exception=e)
+            finally:
+                # Throw exception raised by `autotuner.fn.run`
+                raise
+
+        autotuner.post_hook(full_nargs, exception=None)
+
+    do_bench_npu(
+        kernel_call, prof_dir=autotuner.auto_profile_dir, keep_res=True
+    )
+
+def ext_Autotuner_profile(autotuner, used_cached_result, args, kwargs):
+    if not used_cached_result and autotuner.auto_profile_dir is not None:
+        _profile(autotuner, *args, config=autotuner.best_config, **kwargs)
+
+def set_Config_BiShengIR_options(config, bishengir_options):
+    # BiShengIR Options allowed for autotune
+    config.multibuffer = bishengir_options.get("multibuffer", None)  # Compiler Default True
+    config.unit_flag = bishengir_options.get("unit_flag", None)  # Compiler Default False
+    config.limit_auto_multi_buffer_only_for_local_buffer = bishengir_options.get("limit_auto_multi_buffer_only_for_local_buffer", None)  # Compiler Default False
+    config.limit_auto_multi_buffer_of_local_buffer = bishengir_options.get("limit_auto_multi_buffer_of_local_buffer", None)  # Compiler Default no-limit
+    config.set_workspace_multibuffer = bishengir_options.get("set_workspace_multibuffer", None)  # Compiler Default 1
+    config.enable_hivm_auto_cv_balance = bishengir_options.get("enable_hivm_auto_cv_balance", None)  # Compiler Default True
+    config.tile_mix_vector_loop = bishengir_options.get("tile_mix_vector_loop", None)  # Compiler Default 1
+    config.tile_mix_cube_loop = bishengir_options.get("tile_mix_cube_loop", None)  # Compiler Default 1
+
+def ext_Config_all_kwargs(config):
+    return (
+        ("multibuffer", config.multibuffer),
+        ("enable_hivm_auto_cv_balance", config.enable_hivm_auto_cv_balance),
+        ("unit_flag", config.unit_flag),
+        ("limit_auto_multi_buffer_only_for_local_buffer", \
+            config.limit_auto_multi_buffer_only_for_local_buffer),
+        ("limit_auto_multi_buffer_of_local_buffer", config.limit_auto_multi_buffer_of_local_buffer),
+        ("set_workspace_multibuffer", config.set_workspace_multibuffer),
+        ("tile_mix_vector_loop", config.tile_mix_vector_loop),
+        ("tile_mix_cube_loop", config.tile_mix_cube_loop)
+    )
+
+def ext_Config_to_str(res, config):
+    res.append(f"multibuffer: {config.multibuffer}")
+    res.append(f"enable_hivm_auto_cv_balance: {config.enable_hivm_auto_cv_balance}")
+    res.append(f"unit_flag: {config.unit_flag}")
+    res.append(f"limit_auto_multi_buffer_only_for_local_buffer: \
+        {config.limit_auto_multi_buffer_only_for_local_buffer}")
+    res.append(f"limit_auto_multi_buffer_of_local_buffer: {config.limit_auto_multi_buffer_of_local_buffer}")
+    res.append(f"set_workspace_multibuffer: {config.set_workspace_multibuffer}")
+    res.append(f"tile_mix_vector_loop: {config.tile_mix_vector_loop}")
+    res.append(f"tile_mix_cube_loop: {config.tile_mix_cube_loop}")
+
+def new_AutoTilingTuner(fn, configs, key, reset_to_zero, restore_value, pre_hook,
+                        post_hook, prune_configs_by, warmup, rep,
+                        use_cuda_graph, do_bench, auto_profile_dir,
+                        split_params, tiling_params, low_dims,
+                        dual_reduction, persistent_reduction):
+    from triton.runtime.autotiling_tuner import AutoTilingTuner
+    return AutoTilingTuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value,
pre_hook=pre_hook, + post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, + use_cuda_graph=use_cuda_graph, do_bench=do_bench, auto_profile_dir=auto_profile_dir, + split_params=split_params, tiling_params=tiling_params, low_dims=low_dims, + dual_reduction=dual_reduction, persistent_reduction=persistent_reduction) diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/jit.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/jit.py new file mode 100644 index 000000000..5dd119f27 --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/jit.py @@ -0,0 +1,64 @@ +def is_set_stream_in_kwargs(kwargs): + return True if ('stream' not in kwargs.keys()) else False + +def is_stream_option_deprecated(): + return False + +def ignore_params_in_JITFunction_run(kwargs, excess_kwargs): + ignor_params = ["debug", "sanitize_overflow", "llvm_version", "kernel_name", \ + "allowed_dot_input_precisions", "multibuffer", "stream"] + not_work_params = [] + for k in kwargs: + if k in ignor_params: + continue + elif k in excess_kwargs: + not_work_params.append(k) + if len(not_work_params) != 0: + print("[WARNING] Please DO NOT tune args {}!".format(not_work_params)) + +def set_stream_from_kwargs(kwargs, stream): + if ('stream' in kwargs.keys()): + return kwargs["stream"] + return stream + +def check_grid_size(grid_0, grid_1, grid_2): + import os + grid_all_size = grid_0 * grid_1 * grid_2 + if os.getenv("TRITON_ALL_BLOCKS_PARALLEL", "0") == "0": + if grid_all_size > 65535: + raise RuntimeError("grid should be less than 65536! You can try \"export TRITON_ALL_BLOCKS_PARALLEL=1\" to avoid this problem.") + +def explicit_load_kernel_library(kernel): + # explicitly define run method and load kernel binary + kernel._init_handles() + +def is_JITFunction_spec_attr(): + return True + +def get_JITFunction_spec_attr(deserialized_obj): + from triton.backends.ascend.compiler import AscendAttrsDescriptor + return AscendAttrsDescriptor.from_dict(deserialized_obj['attrs']) + +def maps_line_numbers_to_comment_hints(jit_fn): + import tokenize + from io import StringIO + # Maps line numbers to comment hints + line_flagtree_hints = {} + code_str = jit_fn.src + g = tokenize.generate_tokens(StringIO(code_str).readline) + for tok_type, tok_text, start, end, _ in g: + if tok_type == tokenize.COMMENT: + comment = tok_text.replace(" ", "").strip() + if comment.startswith('#@hint:'): + flagtree_hints = comment[len('#@hint:'):].strip() + # Record the line number of the comment + line_num = start[0] + line_flagtree_hints[line_num] = flagtree_hints + + # print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}") + + return line_flagtree_hints + +def attach_line_number_to_comment_mapping(tree, line_flagtree_hints): + # Attach the line number to comment mapping to the function definition node + tree.body[0].line_flagtree_hints = line_flagtree_hints diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/testing.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/testing.py new file mode 100644 index 000000000..15c856d9a --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/testing.py @@ -0,0 +1,316 @@ +import torch +import os + +def is_do_bench_npu(): + enable_bench_npu = os.getenv("TRITON_BENCH_METHOD", 'default').lower() == 'npu' + if torch.npu.is_available() and enable_bench_npu: + return True + return False + + +def collect_files(base_dir): + 
import pandas as pd + for root, dirs, files in os.walk(base_dir): + for file in files: + if file != 'op_statistic.csv': + continue + target_file = os.path.join(root, file) + df = pd.read_csv(target_file) + triton_rows = df[df['OP Type'].str.startswith('triton', na=False)] + if not triton_rows.empty: + return triton_rows['Avg Time(us)'].values[0] + return float('inf') + return float('inf') + + +def collect_single(base_dir: str, key: str = None) -> float: + if not os.path.exists(base_dir): + return float('inf') + + import pandas as pd + for root, _, files in os.walk(base_dir): + for file in files: + if file != 'op_statistic.csv': + continue + target_file = os.path.join(root, file) + df = pd.read_csv(target_file) + if key is not None: + key_rows = df[df['OP Type'].str.startswith(key, na=False)] + if not key_rows.empty: + return key_rows['Avg Time(us)'].values[0] + return float('inf') + else: + # default: read the first row except header + return df.loc[0, 'Avg Time(us)'] + + return float('inf') + + +def do_bench_npu(fn, warmup=5, active=30, prof_dir=None, keep_res=False): + import torch_npu + import multiprocessing + from triton import runtime + from datetime import datetime, timezone + + # warmup kernel + fn() + torch.npu.synchronize() + + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + l2_cache=False, + data_simplification=False + ) + skip_first = 1 + wait = 0 + repeat = 1 + total = skip_first + (wait + warmup + active) * repeat + + if prof_dir is not None: + torch_path = prof_dir + else: + process = multiprocessing.current_process() + pid = process.pid + process_name = process.name + timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S") + base_path = os.path.join(runtime.cache.get_home_dir(), ".triton", "profile_results") + torch_path = os.path.join(base_path, f"prof_{timestamp}_{process_name}-{pid}") + with torch_npu.profiler.profile( + activities=[ + torch_npu.profiler.ProfilerActivity.NPU + ], + schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, skip_first=skip_first), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path), + record_shapes=False, + profile_memory=False, + with_stack=False, + with_flops=False, + with_modules=False, + experimental_config=experimental_config, + ) as prof: + for _ in range(total): + fn() + prof.step() + torch.npu.synchronize() + + time = collect_single(torch_path) + + if not keep_res: + import shutil + if os.path.exists(torch_path): + shutil.rmtree(torch_path) + + return time + + +def ext_do_bench_npu(fn, warmup, rep, quantiles, return_mode): + import torch + from triton.testing import _summarize_statistics + avg_time = do_bench_npu(fn, warmup=max(5, warmup), active=max(30, rep)) + return _summarize_statistics(torch.tensor([avg_time], dtype=torch.float), quantiles, return_mode) + + +def patch_triton_language(): + # Patch the triton language API here because triton's __init__.py + # import testing in the last stages. 
+ from triton.language.tensor_descriptor import ( + tensor_descriptor, + tensor_descriptor_type, + ) + + from triton.language.core_ext import ( + dot, + cast, + gather, + get_element, + insert_slice, + extract_slice, + trans, + __add__, + __radd__, + __sub__, + __rsub__, + __mul__, + __rmul__, + __lshift__, + __rshift__, + parallel, + compile_hint, + make_tensor_descriptor, + load_tensor_descriptor, + store_tensor_descriptor, + multibuffer, + sync_block_all, + sync_block_set, + sync_block_wait, + dtype_to_ir, + sort + ) + from triton.language.standard_ext import flip, sigmoid, softmax, isfinited, finitef, rint, atan2 + from triton.language.math_ext import ( + umulhi, + exp, + exp2, + log, + log2, + cos, + sin, + sqrt, + sqrt_rn, + rsqrt, + div_rn, + erf, + tanh, + floor, + ceil, + _check_dtype, + fma, + ) + from triton.language.semantic_ext import ( + arange, + floordiv, + atom_red_typechecking_impl, + atomic_cas, + atomic_max, + atomic_min, + _load_legacy, + maximum, + minimum, + mod, + invert, + logical_and, + logical_or, + not_, + and_, + or_, + xor_, + minus, + dot_scaled, + ) + from triton import language + + language.cast = cast + language.dot = dot + language.flip = flip + language.sigmoid = sigmoid + language.softmax = softmax + language.gather = gather + language.insert_slice = insert_slice + language.extract_slice = extract_slice + language.get_element = get_element + language.tensor.__add__ = __add__ + language.tensor.__radd__ = __radd__ + language.tensor.__sub__ = __sub__ + language.tensor.__rsub__ = __rsub__ + language.tensor.__mul__ = __mul__ + language.tensor.__rmul__ = __rmul__ + language.tensor.__lshift__ = __lshift__ + language.tensor.__rshift__ = __rshift__ + language.trans = trans + language.parallel = parallel + language.compile_hint = compile_hint + language.sort = sort + language.multibuffer = multibuffer + language.sync_block_all = sync_block_all + language.sync_block_set = sync_block_set + language.sync_block_wait = sync_block_wait + language.make_tensor_descriptor = make_tensor_descriptor + language.tensor_descriptor = tensor_descriptor + language.tensor_descriptor_type = tensor_descriptor_type + language.load_tensor_descriptor = load_tensor_descriptor + language.store_tensor_descriptor = store_tensor_descriptor + + language.semantic.arange = arange + language.semantic.floordiv = floordiv + language.semantic.atom_red_typechecking_impl = atom_red_typechecking_impl + language.semantic.atomic_cas = atomic_cas + language.semantic.atomic_max = atomic_max + language.semantic.atomic_min = atomic_min + language.semantic._load_legacy = _load_legacy + language.semantic.maximum = maximum + language.semantic.minimum = minimum + language.semantic.invert = invert + language.semantic.logical_and = logical_and + language.semantic.logical_or = logical_or + language.semantic.mod = mod + language.semantic.not_ = not_ + language.semantic.and_ = and_ + language.semantic.or_ = or_ + language.semantic.xor_ = xor_ + language.semantic.minus = minus + language.semantic.dot_scaled = dot_scaled + + language.umulhi = umulhi + language.exp = exp + language.exp2 = exp2 + language.log = log + language.log2 = log2 + language.cos = cos + language.sin = sin + language.sqrt = sqrt + language.sqrt_rn = sqrt_rn + language.rsqrt = rsqrt + language.div_rn = div_rn + language.erf = erf + language.tanh = tanh + language.floor = floor + language.ceil = ceil + language.core.dtype.to_ir = dtype_to_ir + language.fma = fma + language.math.umulhi = umulhi + language.math.exp = exp + language.math.exp2 = exp2 + 
language.math.log = log + language.math.log2 = log2 + language.math.cos = cos + language.math.sin = sin + language.math.sqrt = sqrt + language.math.sqrt_rn = sqrt_rn + language.math.rsqrt = rsqrt + language.math.div_rn = div_rn + language.math.erf = erf + language.math.tanh = tanh + language.math.floor = floor + language.math.ceil = ceil + language.math._check_dtype = _check_dtype + language.math.fma = fma + language.math.isnan = language.extra.ascend.libdevice.isnan + language.math.isinf = language.extra.ascend.libdevice.isinf + language.math.reciprocal = language.extra.ascend.libdevice.reciprocal + language.math.log1p = language.extra.ascend.libdevice.log1p + language.math.relu = language.extra.ascend.libdevice.relu + language.math.tan = language.extra.ascend.libdevice.tan + language.math.atan = language.extra.ascend.libdevice.atan + language.math.tanh = language.extra.ascend.libdevice.tanh + language.math.ilogb = language.extra.ascend.libdevice.ilogb + language.math.ldexp = language.extra.ascend.libdevice.ldexp + language.math.pow = language.extra.ascend.libdevice.pow + language.math.flip = language.extra.ascend.libdevice.flip + language.math.atan2 = language.extra.ascend.libdevice.atan2 + language.math.div_rz = language.extra.ascend.libdevice.div_rz + language.math.fmod = language.extra.ascend.libdevice.fmod + language.math.trunc = language.extra.ascend.libdevice.trunc + language.math.round = language.extra.ascend.libdevice.round + language.math.finitef = finitef + language.math.isfinited = isfinited + language.math.rint = rint + language.math.atan2 = atan2 + language.extra.ascend.libdevice.umulhi = language.math.umulhi + language.extra.ascend.libdevice.exp = language.math.exp + language.extra.ascend.libdevice.exp2 = language.math.exp2 + language.extra.ascend.libdevice.log = language.math.log + language.extra.ascend.libdevice.log2 = language.math.log2 + language.extra.ascend.libdevice.cos = language.math.cos + language.extra.ascend.libdevice.sin = language.math.sin + language.extra.ascend.libdevice.sqrt = language.math.sqrt + language.extra.ascend.libdevice.sqrt_rn = language.math.sqrt_rn + language.extra.ascend.libdevice.rsqrt = language.math.rsqrt + language.extra.ascend.libdevice.div_rn = language.math.div_rn + language.extra.ascend.libdevice.erf = language.math.erf + language.extra.ascend.libdevice.tanh = language.math.tanh + language.extra.ascend.libdevice.floor = language.math.floor + language.extra.ascend.libdevice.ceil = language.math.ceil + language.extra.ascend.libdevice.fdiv = language.math.fdiv + language.extra.ascend.libdevice.fma = language.math.fma + language.extra.ascend.libdevice.abs = language.math.abs \ No newline at end of file diff --git a/third_party/ascend/language/ascend/libdevice.py b/third_party/ascend/language/ascend/libdevice.py index 9098856ec..b4f057d47 100644 --- a/third_party/ascend/language/ascend/libdevice.py +++ b/third_party/ascend/language/ascend/libdevice.py @@ -1,4 +1,177 @@ +from functools import wraps +from typing import List from triton.language import core +from triton.language.math import _add_math_1arg_docstr, _add_math_2arg_docstr, _add_math_3arg_docstr +from triton.language import semantic + +T = core.TypeVar('T') + + +def _check_dtype(dtypes: List[str]) -> T: + """ + We're following libdevice's convention to check accepted data types for math functions. + It is not a good practice to support all data types as accelerators/GPUs don't support + many float16 and bfloat16 math operations. 
+ We should let the users know that they are using and invoke explicit cast to convert + the data type to the supported one. + """ + + def wrapper(fn): + + @wraps(fn) + def check(*args, **kwargs): + # concatenate args and kwargs + all_args = list(args) + list(kwargs.values()) + for arg in [a for a in all_args if isinstance(a, core.tensor)]: + arg_type = arg.type.scalar.name + if hasattr(arg, 'was_bool_to_int8') and arg.was_bool_to_int8: + # In Triton, int1 maps to the boolean type + arg_type = 'int1' + if arg_type not in dtypes: + raise ValueError(f"Expected dtype {dtypes} but got {arg_type}") + return fn(*args, **kwargs) + + return check + + return wrapper + + +@core.extern +@_check_dtype(dtypes=["int32", "uint32"]) +@_add_math_2arg_docstr("most significant N bits of the 2N-bit product") +def umulhi(x, y, _builder=None): + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("exponential") +@core._tensor_member_fn +def exp(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_exp(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("exponential (base 2)") +@core._tensor_member_fn +def exp2(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_exp2(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("natural logarithm") +@core._tensor_member_fn +def log(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_log(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("logarithm (base 2)") +@core._tensor_member_fn +def log2(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_log2(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("cosine") +@core._tensor_member_fn +def cos(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_cos(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("sine") +@core._tensor_member_fn +def sin(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_sin(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("fast square root") +@core._tensor_member_fn +def sqrt(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_sqrt(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)") +@core._tensor_member_fn +def sqrt_rn(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_precise_sqrt(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("inverse square root") +@core._tensor_member_fn +def rsqrt(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_rsqrt(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_2arg_docstr("precise division (rounding to nearest 
wrt the IEEE standard)")
+def div_rn(x, y, _builder=None):
+    x = semantic.to_tensor(x, _builder)
+    y = semantic.to_tensor(y, _builder)
+    x, y = core.binary_op_type_legalization(x, y, _builder)
+    return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type)
+
+@core.extern
+@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
+@_add_math_1arg_docstr("error function")
+@core._tensor_member_fn
+def erf(x, _builder=None):
+    x = semantic.to_tensor(x, _builder)
+    return core.tensor(_builder.create_erf(x.handle), x.type)
+
+@core.extern
+@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
+@_add_math_1arg_docstr("hyperbolic tangent")
+@core._tensor_member_fn
+def tanh(x, _builder=None):
+    x = semantic.to_tensor(x, _builder)
+    return core.tensor(_builder.create_tanh(x.handle), x.type)
+
+@core.extern
+@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
+@_add_math_1arg_docstr("floor")
+@core._tensor_member_fn
+def floor(x, _builder=None):
+    x = semantic.to_tensor(x, _builder)
+    return core.tensor(_builder.create_floor(x.handle), x.type)
+
+
+@core.extern
+@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
+@_add_math_1arg_docstr("ceil")
+@core._tensor_member_fn
+def ceil(x, _builder=None):
+    x = semantic.to_tensor(x, _builder)
+    return core.tensor(_builder.create_ceil(x.handle), x.type)
+
+
+@core.extern
+@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
+@_add_math_3arg_docstr("fused multiply-add")
+def fma(x, y, z, _builder=None):
+    x = semantic.to_tensor(x, _builder)
+    y = semantic.to_tensor(y, _builder)
+    z = semantic.to_tensor(z, _builder)
+    x, y = core.binary_op_type_legalization(x, y, _builder)
+    z, x = core.binary_op_type_legalization(z, x, _builder)
+    z, y = core.binary_op_type_legalization(z, y, _builder)
+    return core.tensor(_builder.create_fma(x.handle, y.handle, z.handle), x.type)
+
 @core.extern
 def reciprocal(arg0, _builder=None):
@@ -151,5 +324,5 @@ def trunc(arg0, _builder=None):
 def round(arg0, _builder=None):
     return core.extern_elementwise(
         "", "", [arg0], {
-            (core.dtype("fp32"), ): ("__hmf_roundf", core.dtype("fp32")),
+            (core.dtype("fp32"), ): ("__hmf_roundf", core.dtype("fp32")),
        }, is_pure=True, _builder=_builder)
\ No newline at end of file
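
The cube/vector synchronization helpers added in this patch (`ext_core_sync_block_all`, `ext_core_sync_block_set`, `ext_core_sync_block_wait`) are exposed to kernels as `tl.sync_block_all` / `tl.sync_block_set` / `tl.sync_block_wait` by `patch_triton_language`. The sketch below is a hypothetical usage example rather than code from this patch; only the argument conventions come from the checks above (sender/receiver must be `"cube"` or `"vector"` and must differ, `event_id` must be in 0..15, and mode must be `"all_cube"`, `"all_vector"`, or `"all"`).

```python
import triton
import triton.language as tl


@triton.jit
def mixed_cube_vector_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    # Hypothetical mixed kernel: the cube pipeline produces data that the
    # vector pipeline consumes, so the two sides are ordered explicitly.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs, mask=offs < n)

    # cube -> vector handshake on event 3 (event_id must be in 0..15)
    tl.sync_block_set("cube", "vector", 3)   # signalled on the cube side
    tl.sync_block_wait("cube", "vector", 3)  # vector side waits for event 3

    tl.store(y_ptr + offs, x * 2, mask=offs < n)

    # full barrier across all cores before any following phase
    tl.sync_block_all("all", 0)
```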
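To make the comment-hint plumbing in `maps_line_numbers_to_comment_hints` concrete, the standalone sketch below repeats the same `tokenize` scan on a hypothetical kernel source. The kernel and the hint name `some_hint` are made up for illustration; only the comment syntax (`# @hint: ...`, with spaces ignored) and the `{line_number: hint_text}` result shape follow from the code above.

```python
import tokenize
from io import StringIO

# Hypothetical kernel source; in the real flow this string is `jit_fn.src`.
src = '''@triton.jit
def add_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs, mask=offs < n)  # @hint: some_hint
    tl.store(out_ptr + offs, x, mask=offs < n)
'''

# Same scan as maps_line_numbers_to_comment_hints: keep only comments that
# start with "#@hint:" once spaces are stripped, keyed by source line number.
line_flagtree_hints = {}
for tok_type, tok_text, start, _, _ in tokenize.generate_tokens(StringIO(src).readline):
    if tok_type == tokenize.COMMENT:
        comment = tok_text.replace(" ", "").strip()
        if comment.startswith('#@hint:'):
            line_flagtree_hints[start[0]] = comment[len('#@hint:'):].strip()

print(line_flagtree_hints)  # {4: 'some_hint'}
```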
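A hedged sketch of how the `tl.sort` wired up above might be called from a kernel. The `dim`/`descending` parameter names are assumed to mirror `ext_semantic_sort`; the documented constraints come from this patch: the element type must be int8/int16/bf16/fp16/fp32, only the last dimension can be sorted, and for int8/int16 inputs the wrapper additionally attaches an `overflow_mode = "saturate"` compile hint.

```python
import triton
import triton.language as tl


@triton.jit
def sort_row_kernel(x_ptr, out_ptr, N: tl.constexpr):
    # Hypothetical: load one row of length N and sort it along the last axis.
    offs = tl.arange(0, N)
    x = tl.load(x_ptr + offs)
    # Only the trailing dimension may be sorted; descending may be a constexpr.
    y = tl.sort(x, dim=-1, descending=True)
    tl.store(out_ptr + offs, y)
```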
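`do_bench_npu` wraps a callable in a `torch_npu` profiler session and returns the average `Avg Time(us)` parsed from the generated `op_statistic.csv` (via `collect_single`); when `keep_res` is false the profiling directory is deleted afterwards. Below is a minimal usage sketch, assuming `torch_npu` is installed and that this backend's `testing.py` is reachable as `triton.testing`, as the `_profile` helper above assumes. The `launch` callable is a stand-in for a real kernel launch.

```python
import torch
import torch_npu  # noqa: F401  (registers the "npu" device with PyTorch)
from triton.testing import do_bench_npu  # provided by the Ascend specialization

x = torch.randn(1024, device="npu")
y = torch.empty_like(x)

def launch():
    # Stand-in for a real kernel launch, e.g. my_kernel[grid](x, y, ...).
    y.copy_(x)

# Profiles `launch` for a few warmup + active steps and averages the NPU time.
avg_us = do_bench_npu(launch, warmup=5, active=30)
print(f"average NPU time: {avg_us:.2f} us")
```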