From 260a42cd596fa6c65fcefb44c18c5e3304823ef8 Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Wed, 5 Nov 2025 16:56:36 +0800 Subject: [PATCH 01/23] [Decoupling] ascend compiler.py decoupling --- python/triton/compiler/compiler.py | 46 +++++++++---------- python/triton/runtime/driver.py | 20 ++++++++ .../__init__.py | 9 ++++ .../triton/compiler/compiler.py | 32 +++++++++++++ 4 files changed, 84 insertions(+), 23 deletions(-) create mode 100644 third_party/ascend/backend/flagtree_backend_specialization/__init__.py create mode 100644 third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 6a8359d6f..a7bc26bd2 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -4,7 +4,6 @@ from .._C.libtriton import get_cache_invalidating_env_vars, ir from ..backends import backends from ..backends.compiler import GPUTarget, AttrsDescriptor -from ..backends.ascend.compiler import AscendAttrsDescriptor from .. import __version__ from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager @@ -12,7 +11,6 @@ from ..tools.disasm import get_sass # TODO: this shouldn't be here from .code_generator import ast_to_ttir -from .errors import MLIRCompilationError from pathlib import Path import re import functools @@ -87,8 +85,9 @@ def __init__(self, fn, signature, constants=None, attrs=None) -> None: for k in self.constants.keys(): if not isinstance(k, str): raise TypeError("Constants keys must be string") - if self.attrs is None: - self.attrs = AscendAttrsDescriptor() + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("ext_ASTSource_attrs", self) def hash(self): sorted_sig = [v for k, v in sorted(self.signature.items())] @@ -252,12 +251,11 @@ def compile(src, target=None, options=None): # cache hit! metadata = json.loads(Path(metadata_path).read_text()) return CompiledKernel(src, metadata_group, hash) - compile_speed_opt = os.getenv("TRITON_ASCEND_COMPILE_SPEED_OPT", 'false').lower() in ('true', '1') - if (compile_speed_opt): - ttir_path = f"{file_name}.ttir" - if (metadata_path is None) and (fn_cache_manager.has_file(ttir_path)): - # Already compile once but failed. 
So directly return - raise Exception("already failed once") + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("opt_ascend_compile_speed", file_name, metadata_path, fn_cache_manager) + # initialize metadata metadata = { "hash": hash, @@ -287,14 +285,10 @@ def compile(src, target=None, options=None): try: next_module = compile_ir(module, metadata) except Exception as e: - if (ext == "ttadapter"): - stage_name = "ConvertTritonIRToLinalgIR" - elif (ext == "npubin"): - stage_name = "ConvertLinalgRToBinary" - else: - stage_name = "MLIRCompile" - error_detail = e.stderr.decode('utf-8') if hasattr(e, 'stderr') and e.stderr else str(e) - raise MLIRCompilationError(stage_name, error_detail) + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("handle_compile_error", e, ext) + ir_filename = f"{file_name}.{ext}" if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None): print(f"\nOverriding kernel with file {full_name}") @@ -406,9 +400,12 @@ def _init_handles(self): # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary( self.name, self.kernel, self.metadata.shared, device) - - # This mechanism introduces heavy runtime overhead. - # Commenting __getattribute__ requires explicitly calling _init_handles() + def __getattribute__(self, name): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if name == 'run' and flagtree_backend_specialization("is_CompiledKernel_getattribute_need_init_handles"): + self._init_handles() + return super().__getattribute__(name) def launch_metadata(self, grid, stream, *args): if CompiledKernel.launch_enter_hook is None: @@ -431,8 +428,11 @@ def __getitem__(self, grid): self._init_handles() def runner(*args, stream=None): - if stream is None: - stream = self.metadata.stream + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("set_CompiledKernel_metadata_stream", self, stream) + if stream is None: device = driver.active.get_current_device() stream = driver.active.get_current_stream(device) diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index c3b97a764..f78c4b9e2 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -58,3 +58,23 @@ def reset_active(self): driver = DriverConfig() + + +# flagtree backend specialization +def flagtree_backend_specialization(function_name: str, *args, **kwargs): + if hasattr(driver.active, "flagtree_backend_specialization"): + flagtree_backend_specialization = driver.active.flagtree_backend_specialization + if hasattr(flagtree_backend_specialization, function_name): + func = getattr(flagtree_backend_specialization, function_name) + return func(*args, **kwargs) + return None + + +# flagtree backend func specialization +def flagtree_backend_func_specialization(function_name: str): + if hasattr(driver.active, "flagtree_backend_specialization"): + flagtree_backend_specialization = driver.active.flagtree_backend_specialization + if hasattr(flagtree_backend_specialization, function_name): + func = getattr(flagtree_backend_specialization, function_name) + return func + return None \ No newline at end of file diff --git 
a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py new file mode 100644 index 000000000..36af8c813 --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py @@ -0,0 +1,9 @@ +from .triton.compiler.compiler import * + +__all__ = [ + 'ext_ASTSource_attrs', + 'opt_ascend_compile_speed', + 'set_CompiledKernel_metadata_stream', + 'handle_compile_error', + 'is_CompiledKernel_getattribute_need_init_handles' +] \ No newline at end of file diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py new file mode 100644 index 000000000..12f9613a6 --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py @@ -0,0 +1,32 @@ +def ext_ASTSource_attrs(ast_source): + from triton.backends.ascend.compiler import AscendAttrsDescriptor + if ast_source.attrs is None: + ast_source.attrs = AscendAttrsDescriptor() + +def opt_ascend_compile_speed(file_name, metadata_path, fn_cache_manager): + import os + compile_speed_opt = os.getenv("TRITON_ASCEND_COMPILE_SPEED_OPT", 'false').lower() in ('true', '1') + if (compile_speed_opt): + ttir_path = f"{file_name}.ttir" + if (metadata_path is None) and (fn_cache_manager.has_file(ttir_path)): + # Already compile once but failed. So directly return + raise Exception("already failed once") + +def set_CompiledKernel_metadata_stream(compiled_kernel, stream): + if stream is None: + return stream + return compiled_kernel.metadata.stream + +def handle_compile_error(e, ext): + from triton.compiler.errors import MLIRCompilationError + if (ext == "ttadapter"): + stage_name = "ConvertTritonIRToLinalgIR" + elif (ext == "npubin"): + stage_name = "ConvertLinalgRToBinary" + else: + stage_name = "MLIRCompile" + error_detail = e.stderr.decode('utf-8') if hasattr(e, 'stderr') and e.stderr else str(e) + raise MLIRCompilationError(stage_name, error_detail) + +def is_CompiledKernel_getattribute_need_init_handles(): + return False \ No newline at end of file From a6758966bcd6e841b1d86d62ae2dae5e0faa6e70 Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Wed, 5 Nov 2025 17:38:27 +0800 Subject: [PATCH 02/23] [Decoupling] ascend errors.py decoupling --- python/triton/compiler/errors.py | 18 +++------------ python/triton/runtime/driver.py | 2 +- .../__init__.py | 6 +++-- .../triton/compiler/compiler.py | 2 +- .../triton/compiler/errors.py | 23 +++++++++++++++++++ 5 files changed, 32 insertions(+), 19 deletions(-) create mode 100644 third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/errors.py diff --git a/python/triton/compiler/errors.py b/python/triton/compiler/errors.py index 5242258ad..c678a5a75 100644 --- a/python/triton/compiler/errors.py +++ b/python/triton/compiler/errors.py @@ -51,18 +51,6 @@ class UnsupportedLanguageConstruct(CompilationError): pass -class MLIRCompilationError(TritonError): - def __init__(self, stage_name: Optional[str], message: Optional[str] = None): - self.stage_name = stage_name - self.message = f"\n" \ - f"{self.format_line_delim('[ERROR][Triton][BEG]')}" \ - f"[{self.stage_name}] encounters error:\n" \ - f"{self.filter_message(message)}" \ - f"{self.format_line_delim('[ERROR][Triton][END]')}" - def __str__(self): - return self.message - def filter_message(self, message): - # Content starting from "Stack dump without symbol names" means nothing to the 
users - return message.split("Stack dump without symbol names")[0] - def format_line_delim(self, keyword): - return f"///------------------{keyword}------------------\n" \ No newline at end of file +# flagtree backend specialization +from triton.runtime.driver import flagtree_backend_specialization +MLIRCompilationError = flagtree_backend_specialization("ext_MLIRCompilationError") diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index f78c4b9e2..b17e6d1fb 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -77,4 +77,4 @@ def flagtree_backend_func_specialization(function_name: str): if hasattr(flagtree_backend_specialization, function_name): func = getattr(flagtree_backend_specialization, function_name) return func - return None \ No newline at end of file + return None diff --git a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py index 36af8c813..2a36a79a9 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py @@ -1,9 +1,11 @@ from .triton.compiler.compiler import * +from .triton.compiler.errors import * __all__ = [ 'ext_ASTSource_attrs', 'opt_ascend_compile_speed', 'set_CompiledKernel_metadata_stream', 'handle_compile_error', - 'is_CompiledKernel_getattribute_need_init_handles' -] \ No newline at end of file + 'is_CompiledKernel_getattribute_need_init_handles', + 'ext_MLIRCompilationError' +] diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py index 12f9613a6..9bf881463 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py @@ -29,4 +29,4 @@ def handle_compile_error(e, ext): raise MLIRCompilationError(stage_name, error_detail) def is_CompiledKernel_getattribute_need_init_handles(): - return False \ No newline at end of file + return False diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/errors.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/errors.py new file mode 100644 index 000000000..f20c74c53 --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/errors.py @@ -0,0 +1,23 @@ +import importlib.util +import sys +from typing import Optional +from triton.compiler.errors import TritonError + +class MLIRCompilationError(TritonError): + def __init__(self, stage_name: Optional[str], message: Optional[str] = None): + self.stage_name = stage_name + self.message = f"\n" \ + f"{self.format_line_delim('[ERROR][Triton][BEG]')}" \ + f"[{self.stage_name}] encounters error:\n" \ + f"{self.filter_message(message)}" \ + f"{self.format_line_delim('[ERROR][Triton][END]')}" + def __str__(self): + return self.message + def filter_message(self, message): + # Content starting from "Stack dump without symbol names" means nothing to the users + return message.split("Stack dump without symbol names")[0] + def format_line_delim(self, keyword): + return f"///------------------{keyword}------------------\n" + +def ext_MLIRCompilationError(): + return MLIRCompilationError From ea86a8c6c4f7b3701d5900da4b62665787181e35 Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Thu, 6 Nov 
2025 11:02:11 +0800 Subject: [PATCH 03/23] [Decoupling] Add specialization import in ascend backend --- third_party/ascend/backend/driver.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/third_party/ascend/backend/driver.py b/third_party/ascend/backend/driver.py index 8ea0873a8..603bb84cc 100644 --- a/third_party/ascend/backend/driver.py +++ b/third_party/ascend/backend/driver.py @@ -105,6 +105,9 @@ class NPUDriver(DriverBase): def __init__(self): self.utils = NPUUtils() self.launcher_cls = NPULauncher + # flagtree backend specialization + from triton.backends.ascend import flagtree_backend_specialization + self.flagtree_backend_specialization = flagtree_backend_specialization super().__init__() @classmethod From 0af9930a3fb2553c27ff3367d09ddddffbaf07b7 Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Thu, 6 Nov 2025 18:39:26 +0800 Subject: [PATCH 04/23] [Decoupling] ascend code_generator.py decoupling --- python/triton/compiler/code_generator.py | 92 +++++++++---------- .../__init__.py | 15 ++- .../triton/compiler/code_generator.py | 63 +++++++++++++ 3 files changed, 122 insertions(+), 48 deletions(-) create mode 100644 third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/code_generator.py diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 155c2824e..2928cbc14 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -155,10 +155,12 @@ def visit_Return(self, node: ast.Return) -> bool: def visit_Assign(self, node: ast.Assign) -> bool: # There couldn't be an early return + # x = ... return False def visit_AugAssign(self, node: ast.AugAssign) -> bool: # There couldn't be an early return + # x += ... return False def visit_Module(self, node: ast.Module) -> bool: @@ -168,6 +170,13 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> bool: return self._visit_stmts(node.body) def visit_If(self, node: ast.If) -> bool: + # TODO: optimize the following case in which we actually don't have + # a return when static_cond is false: + # if dynamic_cond + # if static_cond + # func_with_return + # else + # func_without_return ret = self._visit_stmts(node.body) if node.orelse: ret = ret or self._visit_stmts(node.orelse) @@ -192,6 +201,9 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n self.begin_line = begin_line - 1 self.builder.set_loc(file_name, begin_line, 0) self.builder.options = options + # dict of functions provided by the backend. Below are the list of possible functions: + # Convert custom types not natively supported on HW. 
+ # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None) self.builder.codegen_fns = codegen_fns self.builder.module_map = {} if module_map is None else module_map self.module = self.builder.create_module() if module is None else module @@ -474,7 +486,10 @@ def visit_AnnAssign(self, node): return self.visit_Assign(node) def visit_Assign(self, node): - # flagtree: First, do normal assignment processing + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("anno_CodeGenerator_visit_Assign") + _names = [] if isinstance(node, ast.AnnAssign): _names += [self.visit(node.target)] @@ -498,30 +513,10 @@ def visit_Assign(self, node): not isinstance(value, native_nontensor_types): value = language.semantic.to_tensor(value, self.builder) self.set_value(name, value) - - # flagtree: After normal processing, check if we need to add hint annotation - if hasattr(node, 'lineno') and hasattr(self, 'jit_fn'): - line_num = node.lineno - # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later - function_def = self.jit_fn.parse() - line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) - flagtree_hints = line_flagtree_hints.get(line_num) - - # Check if this is a tl.load call with dot_pad_only_k hint - if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and - isinstance(node.value, ast.Call) and - isinstance(node.value.func, ast.Attribute) and - isinstance(node.value.func.value, ast.Name) and - node.value.func.value.id == 'tl' and - node.value.func.attr == 'load'): - - # Add hint annotation to the loaded tensor(s) - for name, value in zip(names, values): - if _is_triton_value(value): - # print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}") - # Create hint annotation - hint_val = self.builder.get_unit_attr() - self.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val) + + #flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("ext_CodeGenerator_visit_Assign_hint_anno", self, node, names, values) def visit_AugAssign(self, node): name = node.target.id @@ -828,6 +823,8 @@ def visit_While(self, node): liveins, insert_block = sr ip, last_loc = self._get_insertion_point_and_loc() + # loop body (the after region) + # loop_block = self.builder.create_block() dummy = self.builder.create_block() self.builder.set_insertion_point_to_start(dummy) self.scf_stack.append(node) @@ -921,8 +918,11 @@ def visit_For(self, node): return num_stages = None loop_unroll_factor = None - bind_sub_block = None - if IteratorClass in [language.range, language.parallel]: + + #flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + bind_sub_block = flagtree_backend_specialization("init_bind_sub_block") + if IteratorClass in [language.range] + ([language.parallel] if flagtree_backend_specialization("is_visit_For_support_parallel") else []): iterator = IteratorClass(*iter_args, **iter_kwargs) # visit iterator arguments # note: only `range` iterator is supported now @@ -932,8 +932,9 @@ def visit_For(self, node): step = iterator.step num_stages = iterator.num_stages loop_unroll_factor = iterator.loop_unroll_factor - if (IteratorClass is language.parallel): - bind_sub_block = iterator.bind_sub_block + + #flagtree backend specialization + bind_sub_block = 
flagtree_backend_specialization("set_bind_sub_block_when_parallel", IteratorClass, iterator, bind_sub_block) elif IteratorClass is range: # visit iterator arguments # note: only `range` iterator is supported now @@ -943,20 +944,10 @@ def visit_For(self, node): step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) else: raise RuntimeError('Only `range` and `static_range` iterators are currently supported') - - # flagtree: After normal processing, check if we need to override bind_sub_block - if hasattr(node, 'lineno') and hasattr(self, 'jit_fn'): - line_num = node.lineno - # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later - function_def = self.jit_fn.parse() - line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) - flagtree_hints = line_flagtree_hints.get(line_num) - - # Check if this is a range/for loop with bind_sub_block hint - if flagtree_hints and 'bind_sub_block' in flagtree_hints: - bind_sub_block = True - # print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}") - + + #flagtree backend specialization + bind_sub_block = flagtree_backend_specialization("check_override_bind_sub_block", self, node, bind_sub_block) + # handle negative constant step (not supported by scf.for in MLIR) negative_step = False if _is_constexpr(step) and step.value < 0: @@ -1021,7 +1012,8 @@ def visit_For(self, node): if loop_unroll_factor is not None: for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor)) if (bind_sub_block is not None) and bind_sub_block: - for_op.set_attr("bind_sub_block", self.builder.get_bool_attr(bind_sub_block)) + #flagtree backend specialization + flagtree_backend_specialization("forop_setattr_for_bind_sub_block", self, for_op, bind_sub_block) self.scf_stack.append(node) self.builder.set_insertion_point_to_start(for_op.get_body(0)) @@ -1105,7 +1097,11 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): generator.visit(fn.parse()) except Exception as e: # Wrap the error in the callee with the location of the call. - raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from e + + #flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + raise CompilationError(self.jit_fn.src, self.cur_node, + repr(e) if flagtree_backend_specialization('need_repr_in_CodeGenerator_CompilationError') else None) from e callee_ret_type = generator.ret_type self.function_ret_types[fn_name] = callee_ret_type @@ -1149,7 +1145,11 @@ def visit_Call(self, node): # itself). But when calling a function, we raise as `from e` to # preserve the traceback of the original error, which may e.g. # be in core.py. 
- raise CompilationError(self.jit_fn.src, node, repr(e)) from e + + #flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + raise CompilationError(self.jit_fn.src, node, + repr(e) if flagtree_backend_specialization('need_repr_in_CodeGenerator_CompilationError') else None) from e if fn in self.builtin_namespace.values(): args = map(_unwrap_if_constexpr, args) diff --git a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py index 2a36a79a9..a5c1087ba 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py @@ -2,10 +2,21 @@ from .triton.compiler.errors import * __all__ = [ + # compiler.compiler 'ext_ASTSource_attrs', 'opt_ascend_compile_speed', 'set_CompiledKernel_metadata_stream', 'handle_compile_error', - 'is_CompiledKernel_getattribute_need_init_handles', - 'ext_MLIRCompilationError' + 'is_CompiledKernel_getattribute_need_init_handles', + # compiler.errors + 'ext_MLIRCompilationError', + # compiler.code_generator + 'anno_CodeGenerator_visit_Assign', + 'ext_CodeGenerator_visit_Assign_hint_anno', + 'init_bind_sub_block', + 'is_visit_For_support_parallel', + 'set_bind_sub_block_when_parallel', + 'check_override_bind_sub_block', + 'forop_setattr_for_bind_sub_block', + 'need_repr_in_CodeGenerator_CompilationError' ] diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/code_generator.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/code_generator.py new file mode 100644 index 000000000..a5c7b794a --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/code_generator.py @@ -0,0 +1,63 @@ +def anno_CodeGenerator_visit_Assign(): + # flagtree: First, do normal assignment processing + return + +def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values): + import ast + from triton.compiler.code_generator import _is_triton_value + # flagtree: After normal processing, check if we need to add hint annotation + if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'): + line_num = node.lineno + # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later + function_def = code_generator.jit_fn.parse() + line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) + flagtree_hints = line_flagtree_hints.get(line_num) + + # Check if this is a tl.load call with dot_pad_only_k hint + if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and + isinstance(node.value, ast.Call) and + isinstance(node.value.func, ast.Attribute) and + isinstance(node.value.func.value, ast.Name) and + node.value.func.value.id == 'tl' and + node.value.func.attr == 'load'): + + # Add hint annotation to the loaded tensor(s) + for name, value in zip(names, values): + if _is_triton_value(value): + # print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}") + # Create hint annotation + hint_val = code_generator.builder.get_unit_attr() + code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val) + +def init_bind_sub_block(): + return None + +def is_visit_For_support_parallel(): + return True + +def set_bind_sub_block_when_parallel(IteratorClass, iterator, bind_sub_block): + import triton.language as language + if (IteratorClass is language.parallel): + return 
iterator.bind_sub_block + return bind_sub_block + +def check_override_bind_sub_block(code_generator, node, bind_sub_block): + # flagtree: After normal processing, check if we need to override bind_sub_block + if hasattr(node, 'lineno') and hasattr(self, 'jit_fn'): + line_num = node.lineno + # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later + function_def = code_generator.jit_fn.parse() + line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) + flagtree_hints = line_flagtree_hints.get(line_num) + + # Check if this is a range/for loop with bind_sub_block hint + if flagtree_hints and 'bind_sub_block' in flagtree_hints: + return True + # print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}") + return bind_sub_block + +def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block): + for_op.set_attr("bind_sub_block", code_generator.builder.get_bool_attr(bind_sub_block)) + +def need_repr_in_CodeGenerator_CompilationError(): + return True From df8b481aa788e32ebcff642f1c76f0962ad6345b Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Thu, 6 Nov 2025 18:43:18 +0800 Subject: [PATCH 05/23] [Decoupling] import code_generator specialization --- .../ascend/backend/flagtree_backend_specialization/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py index a5c1087ba..e1e3e7042 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py @@ -1,5 +1,6 @@ from .triton.compiler.compiler import * from .triton.compiler.errors import * +from .triton.compiler.code_generator import * __all__ = [ # compiler.compiler From 68240fa5f2b2f35121bbbc351363b090bccaceba Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Fri, 7 Nov 2025 14:32:23 +0800 Subject: [PATCH 06/23] [Decoupling] ascend jit.py decoupling --- python/triton/runtime/jit.py | 77 +++++++++---------- .../__init__.py | 14 +++- .../triton/runtime/jit.py | 62 +++++++++++++++ 3 files changed, 111 insertions(+), 42 deletions(-) create mode 100644 third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/jit.py diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 08422611f..aa7d8b800 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -6,14 +6,11 @@ import os import re import textwrap -import tokenize from collections import defaultdict from functools import cached_property from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple from ..runtime.driver import driver -from ..backends.ascend.compiler import AscendAttrsDescriptor from types import ModuleType -from io import StringIO TRITON_MODULE = __name__[:-len(".runtime.jit")] @@ -331,6 +328,7 @@ def __getitem__(self, grid) -> T: memorizes the grid. 
""" return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) + # return cast(T, functools.partial(cast(Callable, self.run), grid=grid)) def serialize_specialization_data(name, signature, constants, attrs, options, key): @@ -568,7 +566,10 @@ def run(self, *args, grid, warmup, **kwargs): # parse options from ..compiler import make_backend device = driver.active.get_current_device() - if ('stream' not in kwargs.keys()): + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if flagtree_backend_specialization("is_set_stream_in_kwargs", kwargs): stream = driver.active.get_current_stream(device) target = driver.active.get_current_target() backend = make_backend(target) @@ -593,20 +594,17 @@ def run(self, *args, grid, warmup, **kwargs): # deprecated arguments assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used" assert "device" not in kwargs, "device option is deprecated; current device will be used" - # assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if flagtree_backend_specialization("is_stream_option_deprecated"): + assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + for k in excess_kwargs: if k not in options.__dict__: raise KeyError("Keyword argument %s was specified but unrecognised" % k) - ignor_params = ["debug", "sanitize_overflow", "llvm_version", "kernel_name", \ - "allowed_dot_input_precisions", "multibuffer", "stream"] - not_work_params = [] - for k in kwargs: - if k in ignor_params: - continue - elif k in excess_kwargs: - not_work_params.append(k) - if len(not_work_params) != 0: - print("[WARNING] Please DO NOT tune args {}!".format(not_work_params)) + + flagtree_backend_specialization("ignore_params_in_JITFunction_run", kwargs, excess_kwargs) bound_vals = tuple(bound_args.values()) @@ -660,16 +658,17 @@ def run(self, *args, grid, warmup, **kwargs): grid_0 = grid[0] grid_1 = grid[1] if grid_size > 1 else 1 grid_2 = grid[2] if grid_size > 2 else 1 - grid_all_size = grid_0 * grid_1 * grid_2 - if os.getenv("TRITON_ALL_BLOCKS_PARALLEL", "0") == "0": - if grid_all_size > 65535: - raise RuntimeError("grid should be less than 65536! 
You can try \"export TRITON_ALL_BLOCKS_PARALLEL=1\" to avoid this problem.") - if ('stream' in kwargs.keys()): - stream = kwargs["stream"] + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("check_grid_size", grid_0, grid_1, grid_2) + stream = flagtree_backend_specialization("set_stream_from_kwargs", kwargs, stream) + # launch kernel launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) - # explicitly define run method and load kernel binary - kernel._init_handles() + + flagtree_backend_specialization("explicit_load_kernel_library", kernel) + kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals) return kernel @@ -749,6 +748,7 @@ def warmup(self, *args, grid, **kwargs): def preload(self, specialization_data): from ..compiler import compile, ASTSource + from triton.backends.compiler import AttrsDescriptor import json import triton.language as tl device = driver.active.get_current_device() @@ -761,7 +761,13 @@ def preload(self, specialization_data): for key, value in deserialized_obj['constants'].items() } signature = dict(deserialized_obj['signature'].items()) - src = ASTSource(self, signature, constants, AscendAttrsDescriptor.from_dict(deserialized_obj['attrs'])) + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + src = ASTSource(self, signature, constants, + flagtree_backend_specialization('get_JITFunction_spec_attr', deserialized_obj) + if flagtree_backend_specialization('is_JITFunction_spec_attr') + else AttrsDescriptor.from_dict(deserialized_obj['attrs'])) options = { key: tuple(value) if isinstance(value, list) else value for key, value in deserialized_obj['options'].items() @@ -775,29 +781,18 @@ def preload(self, specialization_data): # the user might want to monkey-patch self.src dynamically. # Our unit tests do this, for example. 
def parse(self): - # Maps line numbers to comment hints - line_flagtree_hints = {} - code_str = self.src - g = tokenize.generate_tokens(StringIO(code_str).readline) - for tok_type, tok_text, start, end, _ in g: - if tok_type == tokenize.COMMENT: - comment = tok_text.replace(" ", "").strip() - if comment.startswith('#@hint:'): - flagtree_hints = comment[len('#@hint:'):].strip() - # Record the line number of the comment - line_num = start[0] - line_flagtree_hints[line_num] = flagtree_hints - - # print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}") + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + line_flagtree_hints = flagtree_backend_specialization('maps_line_numbers_to_comment_hints', self) tree = ast.parse(self.src) assert isinstance(tree, ast.Module) assert len(tree.body) == 1 assert isinstance(tree.body[0], ast.FunctionDef) - # Attach the line number to comment mapping to the function definition node - tree.body[0].line_flagtree_hints = line_flagtree_hints - + # flagtree backend specialization + flagtree_backend_specialization('attach_line_number_to_comment_mapping', tree, line_flagtree_hints) + return tree def __call__(self, *args, **kwargs): diff --git a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py index e1e3e7042..3fe73cf63 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py @@ -1,6 +1,7 @@ from .triton.compiler.compiler import * from .triton.compiler.errors import * from .triton.compiler.code_generator import * +from .triton.runtime.jit import * __all__ = [ # compiler.compiler @@ -19,5 +20,16 @@ 'set_bind_sub_block_when_parallel', 'check_override_bind_sub_block', 'forop_setattr_for_bind_sub_block', - 'need_repr_in_CodeGenerator_CompilationError' + 'need_repr_in_CodeGenerator_CompilationError', + # runtime.jit + 'is_set_stream_in_kwargs', + 'is_stream_option_deprecated', + 'ignore_params_in_JITFunction_run', + 'set_stream_from_kwargs', + 'check_grid_size', + 'explicit_load_kernel_library', + 'is_JITFunction_spec_attr', + 'get_JITFunction_spec_attr', + 'maps_line_numbers_to_comment_hints', + 'attach_line_number_to_comment_mapping' ] diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/jit.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/jit.py new file mode 100644 index 000000000..7dc30b0dd --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/jit.py @@ -0,0 +1,62 @@ +def is_set_stream_in_kwargs(kwargs): + return True if ('stream' not in kwargs.keys()) else False + +def is_stream_option_deprecated(): + return False + +def ignore_params_in_JITFunction_run(kwargs, excess_kwargs): + ignor_params = ["debug", "sanitize_overflow", "llvm_version", "kernel_name", \ + "allowed_dot_input_precisions", "multibuffer", "stream"] + not_work_params = [] + for k in kwargs: + if k in ignor_params: + continue + elif k in excess_kwargs: + not_work_params.append(k) + if len(not_work_params) != 0: + print("[WARNING] Please DO NOT tune args {}!".format(not_work_params)) + +def set_stream_from_kwargs(kwargs, stream): + if ('stream' in kwargs.keys()): + return kwargs["stream"] + return stream + +def check_grid_size(grid_0, grid_1, grid_2): + import os + grid_all_size = grid_0 * grid_1 * grid_2 + if 
os.getenv("TRITON_ALL_BLOCKS_PARALLEL", "0") == "0": + if grid_all_size > 65535: + raise RuntimeError("grid should be less than 65536! You can try \"export TRITON_ALL_BLOCKS_PARALLEL=1\" to avoid this problem.") + +def explicit_load_kernel_library(kernel): + # explicitly define run method and load kernel binary + kernel._init_handles() + +def is_JITFunction_spec_attr(): + return True + +def get_JITFunction_spec_attr(deserialized_obj): + from triton.backends.ascend.compiler import AscendAttrsDescriptor + return AscendAttrsDescriptor.from_dict(deserialized_obj['attrs']) + +def maps_line_numbers_to_comment_hints(jit_fn): + import tokenize + from io import StringIO + # Maps line numbers to comment hints + line_flagtree_hints = {} + code_str = jit_fn.src + g = tokenize.generate_tokens(StringIO(code_str).readline) + for tok_type, tok_text, start, end, _ in g: + if tok_type == tokenize.COMMENT: + comment = tok_text.replace(" ", "").strip() + if comment.startswith('#@hint:'): + flagtree_hints = comment[len('#@hint:'):].strip() + # Record the line number of the comment + line_num = start[0] + line_flagtree_hints[line_num] = flagtree_hints + + # print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}") + +def attach_line_number_to_comment_mapping(tree, line_flagtree_hints): + # Attach the line number to comment mapping to the function definition node + tree.body[0].line_flagtree_hints = line_flagtree_hints From 7a709efa1cdc80fac766b6c191b04ba927173a26 Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Fri, 7 Nov 2025 18:12:52 +0800 Subject: [PATCH 07/23] [Decoupling] ascend autotuner.py decoupling --- python/triton/runtime/autotuner.py | 122 ++++++------------ .../__init__.py | 13 +- .../triton/runtime/autotuner.py | 99 ++++++++++++++ 3 files changed, 153 insertions(+), 81 deletions(-) create mode 100644 third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/autotuner.py diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 67753e129..95061c610 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -4,7 +4,7 @@ import os import time import inspect -from typing import Dict, List +from typing import Dict from .jit import KernelInterface from .errors import OutOfResources @@ -28,6 +28,7 @@ def __init__( rep=None, use_cuda_graph=False, do_bench=None, + # flagtree backend specialization auto_profile_dir=None, ): """ @@ -36,9 +37,15 @@ def __init__( 'top_k': number of configs to bench 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. """ + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if not configs: self.configs = [ - Config({}) + flagtree_backend_specialization('get_spec_default_Autotuner_configs') + if flagtree_backend_specialization('has_spec_default_Autotuner_configs') + else Config({}, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, + reg_dec_producer=0, reg_inc_consumer=0) ] else: self.configs = configs @@ -100,7 +107,9 @@ def _post_hook(kwargs, exception): self.num_warmups = warmup self.num_reps = rep self.use_cuda_graph = use_cuda_graph - self.auto_profile_dir = auto_profile_dir + + # flagtree backend specialization + flagtree_backend_specialization('set_Autotuner_auto_profile_dir', self, auto_profile_dir) # If we got explicitly called via the old interface, raise a warning # and proceed with the old behavior. 
@@ -133,7 +142,7 @@ def _post_hook(kwargs, exception): self.do_bench = do_bench def _bench(self, *args, config, **meta): - from ..compiler.errors import CompileTimeAssertionFailure, MLIRCompilationError + from ..compiler.errors import CompileTimeAssertionFailure # check for conflicts, i.e. meta-parameters both provided # as kwargs and by the autotuner @@ -163,46 +172,15 @@ def kernel_call(): self.post_hook(full_nargs, exception=None) + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + try: return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) - except (OutOfResources, CompileTimeAssertionFailure, MLIRCompilationError) as e: + except (OutOfResources, CompileTimeAssertionFailure) + \ + flagtree_backend_specialization("ext_Autotuner_do_bench_MLIRCompilationError") as e: return [float("inf"), float("inf"), float("inf")] - def _profile(self, *args, config, **meta): - from triton.testing import do_bench_npu - - # check for conflicts, i.e. meta-parameters both provided - # as kwargs and by the autotuner - conflicts = meta.keys() & config.kwargs.keys() - if conflicts: - raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." - " Make sure that you don't re-define auto-tuned symbols.") - # augment meta-parameters with tunable ones - current = dict(meta, **config.all_kwargs()) - full_nargs = {**self.nargs, **current} - - def kernel_call(): - if config.pre_hook: - config.pre_hook(full_nargs) - self.pre_hook(full_nargs) - try: - self.fn.run( - *args, - **current, - ) - except Exception as e: - try: - self.post_hook(full_nargs, exception=e) - finally: - # Throw exception raised by `self.fn.run` - raise - - self.post_hook(full_nargs, exception=None) - - do_bench_npu( - kernel_call, prof_dir=self.auto_profile_dir, keep_res=True - ) - def run(self, *args, **kwargs): self.nargs = dict(zip(self.arg_names, args)) used_cached_result = True @@ -234,8 +212,10 @@ def run(self, *args, **kwargs): print(f"Triton autotuning for function {self.base_fn.__name__} finished after " f"{self.bench_time:.2f}s; best config selected: {self.best_config};") - if not used_cached_result and self.auto_profile_dir is not None: - self._profile(*args, config=self.best_config, **kwargs) + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_Autotuner_profile', self, used_cached_result, args, kwargs) + if config.pre_hook is not None: full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()} config.pre_hook(full_nargs) @@ -315,18 +295,14 @@ def __init__(self, kwargs, num_warps=None, num_stages=None, num_ctas=None, num_b self.maxnreg = maxnreg self.pre_hook = pre_hook - - # BiShengIR Options allowed for autotune - self.multibuffer = bishengir_options.get("multibuffer", None) # Compiler Default True - self.unit_flag = bishengir_options.get("unit_flag", None) # Compiler Default False - self.limit_auto_multi_buffer_only_for_local_buffer = bishengir_options.get("limit_auto_multi_buffer_only_for_local_buffer", None) # Compiler Default False - self.limit_auto_multi_buffer_of_local_buffer = bishengir_options.get("limit_auto_multi_buffer_of_local_buffer", None) # Compiler Default no-limit - self.set_workspace_multibuffer = bishengir_options.get("set_workspace_multibuffer", None) # Compiler Default 1 - self.enable_hivm_auto_cv_balance = bishengir_options.get("enable_hivm_auto_cv_balance", None) # Compiler Default True - self.tile_mix_vector_loop = 
bishengir_options.get("tile_mix_vector_loop", None) # Compiler Default 1 - self.tile_mix_cube_loop = bishengir_options.get("tile_mix_cube_loop", None) # Compiler Default 1 + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('set_Config_BiShengIR_options', self, bishengir_options) def all_kwargs(self): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return { **self.kwargs, **{ k: v @@ -339,17 +315,7 @@ def all_kwargs(self): ("reg_dec_producer", self.reg_dec_producer), ("reg_inc_consumer", self.reg_inc_consumer), ("maxnreg", self.maxnreg), - - ("multibuffer", self.multibuffer), - ("enable_hivm_auto_cv_balance", self.enable_hivm_auto_cv_balance), - ("unit_flag", self.unit_flag), - ("limit_auto_multi_buffer_only_for_local_buffer", \ - self.limit_auto_multi_buffer_only_for_local_buffer), - ("limit_auto_multi_buffer_of_local_buffer", self.limit_auto_multi_buffer_of_local_buffer), - ("set_workspace_multibuffer", self.set_workspace_multibuffer), - ("tile_mix_vector_loop", self.tile_mix_vector_loop), - ("tile_mix_cube_loop", self.tile_mix_cube_loop), - ) if v is not None + ) + flagtree_backend_specialization('ext_Config_all_kwargs', self) if v is not None } } @@ -366,15 +332,10 @@ def __str__(self): res.append(f"reg_inc_consumer: {self.reg_inc_consumer}") res.append(f"maxnreg: {self.maxnreg}") - res.append(f"multibuffer: {self.multibuffer}") - res.append(f"enable_hivm_auto_cv_balance: {self.enable_hivm_auto_cv_balance}") - res.append(f"unit_flag: {self.unit_flag}") - res.append(f"limit_auto_multi_buffer_only_for_local_buffer: \ - {self.limit_auto_multi_buffer_only_for_local_buffer}") - res.append(f"limit_auto_multi_buffer_of_local_buffer: {self.limit_auto_multi_buffer_of_local_buffer}") - res.append(f"set_workspace_multibuffer: {self.set_workspace_multibuffer}") - res.append(f"tile_mix_vector_loop: {self.tile_mix_vector_loop}") - res.append(f"tile_mix_cube_loop: {self.tile_mix_cube_loop}") + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_Config_to_str', res, self) + return ", ".join(res) @@ -440,15 +401,16 @@ def kernel(x_ptr, x_size, **META): """ def decorator(fn): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization if split_params or tiling_params: - from .autotiling_tuner import AutoTilingTuner - return AutoTilingTuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, - post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, - use_cuda_graph=use_cuda_graph, do_bench=do_bench, auto_profile_dir=auto_profile_dir, - split_params=split_params, tiling_params=tiling_params, low_dims=low_dims, - dual_reduction=dual_reduction, persistent_reduction=persistent_reduction) - else: - return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, + return flagtree_backend_specialization('new_AutoTilingTuner', fn, configs, key, reset_to_zero, restore_value, pre_hook, + post_hook, prune_configs_by, warmup, rep, + use_cuda_graph, do_bench, auto_profile_dir, + split_params, tiling_params, low_dims, + dual_reduction, persistent_reduction) + + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, 
use_cuda_graph=use_cuda_graph, do_bench=do_bench, auto_profile_dir=auto_profile_dir) diff --git a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py index 3fe73cf63..251075978 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py @@ -2,6 +2,7 @@ from .triton.compiler.errors import * from .triton.compiler.code_generator import * from .triton.runtime.jit import * +from .triton.runtime.autotuner import * __all__ = [ # compiler.compiler @@ -31,5 +32,15 @@ 'is_JITFunction_spec_attr', 'get_JITFunction_spec_attr', 'maps_line_numbers_to_comment_hints', - 'attach_line_number_to_comment_mapping' + 'attach_line_number_to_comment_mapping', + # runtime.autotuner + 'set_Autotuner_auto_profile_dir', + 'has_spec_default_Autotuner_configs', + 'get_spec_default_Autotuner_configs', + 'ext_Autotuner_do_bench_MLIRCompilationError', + 'ext_Autotuner_profile', + 'set_Config_BiShengIR_options', + 'ext_Config_all_kwargs', + 'ext_Config_to_str', + 'new_AutoTilingTuner' ] diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/autotuner.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/autotuner.py new file mode 100644 index 000000000..545a9e879 --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/autotuner.py @@ -0,0 +1,99 @@ +def set_Autotuner_auto_profile_dir(autotuner, auto_profile_dir): + autotuner.auto_profile_dir = auto_profile_dir + +def has_spec_default_Autotuner_configs(): + return True + +def get_spec_default_Autotuner_configs(): + from triton.runtime.autotuner import Config + return Config({}) + +def ext_Autotuner_do_bench_MLIRCompilationError(exception_types): + from triton.compiler.errors import MLIRCompilationError + return (MLIRCompilationError) + +def _profile(autotuner, *args, config, **meta): + from triton.testing import do_bench_npu + + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." 
+ " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.all_kwargs()) + full_nargs = {**autotuner.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + autotuner.pre_hook(full_nargs) + try: + autotuner.fn.run( + *args, + **current, + ) + except Exception as e: + try: + autotuner.post_hook(full_nargs, exception=e) + finally: + # Throw exception raised by `autotuner.fn.run` + raise + + autotuner.post_hook(full_nargs, exception=None) + + do_bench_npu( + kernel_call, prof_dir=autotuner.auto_profile_dir, keep_res=True + ) + +def ext_Autotuner_profile(autotuner, used_cached_result, args, kwargs): + if not used_cached_result and autotuner.auto_profile_dir is not None: + _profile(*args, config=autotuner.best_config, **kwargs) + +def set_Config_BiShengIR_options(config, bishengir_options): + # BiShengIR Options allowed for autotune + config.multibuffer = bishengir_options.get("multibuffer", None) # Compiler Default True + config.unit_flag = bishengir_options.get("unit_flag", None) # Compiler Default False + config.limit_auto_multi_buffer_only_for_local_buffer = bishengir_options.get("limit_auto_multi_buffer_only_for_local_buffer", None) # Compiler Default False + config.limit_auto_multi_buffer_of_local_buffer = bishengir_options.get("limit_auto_multi_buffer_of_local_buffer", None) # Compiler Default no-limit + config.set_workspace_multibuffer = bishengir_options.get("set_workspace_multibuffer", None) # Compiler Default 1 + config.enable_hivm_auto_cv_balance = bishengir_options.get("enable_hivm_auto_cv_balance", None) # Compiler Default True + config.tile_mix_vector_loop = bishengir_options.get("tile_mix_vector_loop", None) # Compiler Default 1 + config.tile_mix_cube_loop = bishengir_options.get("tile_mix_cube_loop", None) # Compiler Default 1 + +def ext_Config_all_kwargs(config): + return ( + ("multibuffer", config.multibuffer), + ("enable_hivm_auto_cv_balance", config.enable_hivm_auto_cv_balance), + ("unit_flag", config.unit_flag), + ("limit_auto_multi_buffer_only_for_local_buffer", \ + config.limit_auto_multi_buffer_only_for_local_buffer), + ("limit_auto_multi_buffer_of_local_buffer", config.limit_auto_multi_buffer_of_local_buffer), + ("set_workspace_multibuffer", config.set_workspace_multibuffer), + ("tile_mix_vector_loop", config.tile_mix_vector_loop), + ("tile_mix_cube_loop", config.tile_mix_cube_loop) + ) + +def ext_Config_to_str(res, config): + res.append(f"multibuffer: {config.multibuffer}") + res.append(f"enable_hivm_auto_cv_balance: {config.enable_hivm_auto_cv_balance}") + res.append(f"unit_flag: {config.unit_flag}") + res.append(f"limit_auto_multi_buffer_only_for_local_buffer: \ + {config.limit_auto_multi_buffer_only_for_local_buffer}") + res.append(f"limit_auto_multi_buffer_of_local_buffer: {config.limit_auto_multi_buffer_of_local_buffer}") + res.append(f"set_workspace_multibuffer: {config.set_workspace_multibuffer}") + res.append(f"tile_mix_vector_loop: {config.tile_mix_vector_loop}") + res.append(f"tile_mix_cube_loop: {config.tile_mix_cube_loop}") + +def new_AutoTilingTuner(fn, configs, key, reset_to_zero, restore_value, pre_hook, + post_hook, prune_configs_by, warmup, rep, + use_cuda_graph, do_bench, auto_profile_dir, + split_params, tiling_params, low_dims, + dual_reduction, persistent_reduction): + from triton.runtime.autotiling_tuner import AutoTilingTuner + return AutoTilingTuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, 
pre_hook=pre_hook, + post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, + use_cuda_graph=use_cuda_graph, do_bench=do_bench, auto_profile_dir=auto_profile_dir, + split_params=split_params, tiling_params=tiling_params, low_dims=low_dims, + dual_reduction=dual_reduction, persistent_reduction=persistent_reduction) From 8b6fa01df4ddf47e81d9693bf1235d4c1346204b Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Fri, 7 Nov 2025 18:39:32 +0800 Subject: [PATCH 08/23] [Decoupling] ascend _utils.py decoupling --- python/triton/language/_utils.py | 39 +++++++++--------- .../__init__.py | 8 +++- .../triton/language/_utils.py | 41 +++++++++++++++++++ 3 files changed, 68 insertions(+), 20 deletions(-) create mode 100644 third_party/ascend/backend/flagtree_backend_specialization/triton/language/_utils.py diff --git a/python/triton/language/_utils.py b/python/triton/language/_utils.py index d0ca8c734..df9eaeaaf 100644 --- a/python/triton/language/_utils.py +++ b/python/triton/language/_utils.py @@ -1,19 +1,28 @@ from __future__ import annotations -from typing import List, TYPE_CHECKING, Any, Union, Dict +from typing import List +# flagtree backend specialization +from triton.runtime.driver import flagtree_backend_specialization +from typing import TYPE_CHECKING if TYPE_CHECKING: - from .language import core - IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type] - ObjPath = tuple[int, ...] + IterableType, ObjPath = flagtree_backend_specialization('get_language_utils_IterableType_ObjPath') + + +TRITON_MAX_TENSOR_NUMEL = flagtree_backend_specialization('get_triton_max_tensor_numel') + + +def is_power_of_two(x): + return (x & (x - 1)) == 0 -TRITON_MAX_TENSOR_NUMEL = 1048576 def validate_block_shape(shape: List[int]): numel = 1 for i, d in enumerate(shape): if not isinstance(d, int): raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]") + if flagtree_backend_specialization('is_block_shape_check_power_of_two') and not is_power_of_two(d): + raise ValueError(f"Shape element {i} must be a power of 2") numel *= d if numel > TRITON_MAX_TENSOR_NUMEL: @@ -21,19 +30,11 @@ def validate_block_shape(shape: List[int]): return numel -BITWIDTH_DICT: Dict[str, int] = { - **{f"u{n}": n - for n in (1, 8, 16, 32, 64)}, - **{f"i{n}": n - for n in (1, 8, 16, 32, 64)}, - **{f"fp{n}": n - for n in (16, 32, 64)}, - **{f"fp8{suffix}": 8 - for suffix in ("e4nv", "e4b15", "e4b8", "e5", "e5b16")}, - "bf16": 16, - "void": 0, -} +# flagtree backend specialization +from triton.runtime.driver import flagtree_backend_specialization +BITWIDTH_DICT = flagtree_backend_specialization('get_language_utils_BITWIDTH_DICT') -def get_primitive_bitwidth(dtype: str) -> int: - return BITWIDTH_DICT[dtype] +# flagtree backend specialization +from triton.runtime.driver import flagtree_backend_func_specialization +get_primitive_bitwidth = flagtree_backend_func_specialization("get_primitive_bitwidth") diff --git a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py index 251075978..ddb6fc343 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py @@ -3,6 +3,7 @@ from .triton.compiler.code_generator import * from .triton.runtime.jit import * from .triton.runtime.autotuner import * +from .triton.language._utils import * __all__ = [ # compiler.compiler @@ -42,5 +43,10 @@ 
'set_Config_BiShengIR_options', 'ext_Config_all_kwargs', 'ext_Config_to_str', - 'new_AutoTilingTuner' + 'new_AutoTilingTuner', + # language._utils + 'get_language_utils_IterableType_ObjPath', + 'get_triton_max_tensor_numel', + 'is_block_shape_check_power_of_two', + 'get_language_utils_BITWIDTH_DICT' ] diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/language/_utils.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/_utils.py new file mode 100644 index 000000000..cfbe4681a --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/_utils.py @@ -0,0 +1,41 @@ +from typing import TYPE_CHECKING, Any, Union, Dict + +def get_language_utils_IterableType_ObjPath(): + from triton.language import core + IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type] + ObjPath = tuple[int, ...] + return IterableType, ObjPath + + +TRITON_MAX_TENSOR_NUMEL = 1048576 + +def get_triton_max_tensor_numel(): + return TRITON_MAX_TENSOR_NUMEL + + +def is_block_shape_check_power_of_two(): + return False + + +BITWIDTH_DICT: Dict[str, int] = { + **{f"u{n}": n + for n in (1, 8, 16, 32, 64)}, + **{f"i{n}": n + for n in (1, 8, 16, 32, 64)}, + **{f"fp{n}": n + for n in (16, 32, 64)}, + **{f"fp8{suffix}": 8 + for suffix in ("e4nv", "e4b15", "e4b8", "e5", "e5b16")}, + "bf16": 16, + "void": 0, +} + + +def get_language_utils_BITWIDTH_DICT(): + return BITWIDTH_DICT + + +def get_primitive_bitwidth(dtype: str) -> int: + return BITWIDTH_DICT[dtype] + + From 79adf96edcabedbc2e429610dd8c86f5eb82c56d Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Fri, 7 Nov 2025 18:44:39 +0800 Subject: [PATCH 09/23] [Decouping] Fix class specialization --- python/triton/compiler/errors.py | 4 ++-- python/triton/runtime/driver.py | 9 +++++++++ .../backend/flagtree_backend_specialization/__init__.py | 2 -- .../triton/compiler/errors.py | 3 --- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/python/triton/compiler/errors.py b/python/triton/compiler/errors.py index c678a5a75..6d1cc470a 100644 --- a/python/triton/compiler/errors.py +++ b/python/triton/compiler/errors.py @@ -52,5 +52,5 @@ class UnsupportedLanguageConstruct(CompilationError): # flagtree backend specialization -from triton.runtime.driver import flagtree_backend_specialization -MLIRCompilationError = flagtree_backend_specialization("ext_MLIRCompilationError") +from triton.runtime.driver import flagtree_backend_class_specialization +MLIRCompilationError = flagtree_backend_class_specialization("MLIRCompilationError") diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index b17e6d1fb..c0aa8cb9b 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -78,3 +78,12 @@ def flagtree_backend_func_specialization(function_name: str): func = getattr(flagtree_backend_specialization, function_name) return func return None + +# flagtree backend class specialization +def flagtree_backend_class_specialization(class_name: str): + if hasattr(driver.active, "flagtree_backend_specialization"): + flagtree_backend_specialization = driver.active.flagtree_backend_specialization + if hasattr(flagtree_backend_specialization, class_name): + cls = getattr(flagtree_backend_specialization, class_name) + return cls + return None diff --git a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py index ddb6fc343..9bf0e6863 100644 --- 
a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py @@ -12,8 +12,6 @@ 'set_CompiledKernel_metadata_stream', 'handle_compile_error', 'is_CompiledKernel_getattribute_need_init_handles', - # compiler.errors - 'ext_MLIRCompilationError', # compiler.code_generator 'anno_CodeGenerator_visit_Assign', 'ext_CodeGenerator_visit_Assign_hint_anno', diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/errors.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/errors.py index f20c74c53..b1ef43a3b 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/errors.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/errors.py @@ -18,6 +18,3 @@ def filter_message(self, message): return message.split("Stack dump without symbol names")[0] def format_line_delim(self, keyword): return f"///------------------{keyword}------------------\n" - -def ext_MLIRCompilationError(): - return MLIRCompilationError From dc87390421db108426134356620bcf3af1c45114 Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Fri, 7 Nov 2025 19:04:02 +0800 Subject: [PATCH 10/23] [Decouping] ascend testing.py decoupling --- python/triton/testing.py | 310 +---------------- .../__init__.py | 4 + .../testing.py | 316 ++++++++++++++++++ 3 files changed, 329 insertions(+), 301 deletions(-) create mode 100644 third_party/ascend/backend/flagtree_backend_specialization/testing.py diff --git a/python/triton/testing.py b/python/triton/testing.py index b929ef22c..92d88baee 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -1,8 +1,6 @@ import functools import os import subprocess -import multiprocessing -import os import sys from contextlib import contextmanager from typing import Any, Dict, List @@ -114,10 +112,10 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m assert return_mode in ["min", "max", "mean", "median", "all"] import torch - enable_bench_npu = os.getenv("TRITON_BENCH_METHOD", 'default').lower() == 'npu' - if torch.npu.is_available() and enable_bench_npu: - avg_time = do_bench_npu(fn, warmup=max(5, warmup), active=max(30, rep)) - return _summarize_statistics(torch.tensor([avg_time], dtype=torch.float), quantiles, return_mode) + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if flagtree_backend_specialization('is_do_bench_npu'): + return flagtree_backend_specialization('ext_do_bench_npu') di = runtime.driver.active.get_device_interface() @@ -164,99 +162,6 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) return _summarize_statistics(times, quantiles, return_mode) -def collect_files(base_dir): - import pandas as pd - for root, dirs, files in os.walk(base_dir): - for file in files: - if file != 'op_statistic.csv': - continue - target_file = os.path.join(root, file) - df = pd.read_csv(target_file) - triton_rows = df[df['OP Type'].str.startswith('triton', na=False)] - if not triton_rows.empty: - return triton_rows['Avg Time(us)'].values[0] - return float('inf') - return float('inf') - - -def collect_single(base_dir: str, key: str = None) -> float: - if not os.path.exists(base_dir): - return float('inf') - - import pandas as pd - for root, _, files in os.walk(base_dir): - for file in 
files: - if file != 'op_statistic.csv': - continue - target_file = os.path.join(root, file) - df = pd.read_csv(target_file) - if key is not None: - key_rows = df[df['OP Type'].str.startswith(key, na=False)] - if not key_rows.empty: - return key_rows['Avg Time(us)'].values[0] - return float('inf') - else: - # default: read the first row except header - return df.loc[0, 'Avg Time(us)'] - - return float('inf') - - -def do_bench_npu(fn, warmup=5, active=30, prof_dir=None, keep_res=False): - import torch - import torch_npu - from datetime import datetime, timezone - - # warmup kernel - fn() - torch.npu.synchronize() - - experimental_config = torch_npu.profiler._ExperimentalConfig( - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, - l2_cache=False, - data_simplification=False - ) - skip_first = 1 - wait = 0 - repeat = 1 - total = skip_first + (wait + warmup + active) * repeat - - if prof_dir is not None: - torch_path = prof_dir - else: - process = multiprocessing.current_process() - pid = process.pid - process_name = process.name - timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S") - base_path = os.path.join(runtime.cache.get_home_dir(), ".triton", "profile_results") - torch_path = os.path.join(base_path, f"prof_{timestamp}_{process_name}-{pid}") - with torch_npu.profiler.profile( - activities=[ - torch_npu.profiler.ProfilerActivity.NPU - ], - schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, skip_first=skip_first), - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path), - record_shapes=False, - profile_memory=False, - with_stack=False, - with_flops=False, - with_modules=False, - experimental_config=experimental_config, - ) as prof: - for _ in range(total): - fn() - prof.step() - torch.npu.synchronize() - - time = collect_single(torch_path) - - if not keep_res: - import shutil - if os.path.exists(torch_path): - shutil.rmtree(torch_path) - - return time def assert_close(x, y, atol=None, rtol=None, err_msg=''): """ @@ -431,6 +336,7 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b ax.legend() ax.set_xlabel(bench.xlabel or first_x) ax.set_ylabel(bench.ylabel) + # ax.set_title(bench.plot_name) ax.set_xscale("log" if bench.x_log else "linear") ax.set_yscale("log" if bench.y_log else "linear") if show_plots: @@ -609,205 +515,7 @@ def get_max_simd_tflops(dtype, clock_rate, device=None): tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 return tflops -# Patch the triton language API here because triton's __init__.py -# import testing in the last stages. 
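The do_bench_npu / collect_single helpers removed above move into the Ascend specialization package (see the new testing.py added later in this patch); from a caller's point of view nothing changes: triton.testing.do_bench still consults TRITON_BENCH_METHOD and NPU availability, now through the is_do_bench_npu / ext_do_bench_npu hooks, and otherwise falls back to the usual device-event timing loop. A hypothetical usage sketch, assuming a working torch_npu install and the Ascend Triton backend; the kernel, sizes and grid are illustrative only:

import os

import torch
import torch_npu  # noqa: F401  (registers the "npu" device with torch)
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


# Opt in to profiler-based timing; with the default value (or without an NPU)
# do_bench keeps using the regular event-based measurement loop.
os.environ["TRITON_BENCH_METHOD"] = "npu"

x = torch.randn(1 << 16, device="npu")
y = torch.randn_like(x)
out = torch.empty_like(x)
grid = (x.numel() // 1024, )
ms = triton.testing.do_bench(lambda: add_kernel[grid](x, y, out, x.numel(), BLOCK=1024))
print(f"avg time: {ms} ms")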
-from .language.tensor_descriptor import ( - tensor_descriptor, - tensor_descriptor_type, -) - -from .language.core_ext import ( - dot, - cast, - gather, - get_element, - insert_slice, - extract_slice, - trans, - __add__, - __radd__, - __sub__, - __rsub__, - __mul__, - __rmul__, - __lshift__, - __rshift__, - parallel, - compile_hint, - make_tensor_descriptor, - load_tensor_descriptor, - store_tensor_descriptor, - multibuffer, - sync_block_all, - sync_block_set, - sync_block_wait, - dtype_to_ir, - sort -) -from .language.standard_ext import flip, sigmoid, softmax, isfinited, finitef, rint, atan2 -from .language.math_ext import ( - umulhi, - exp, - exp2, - log, - log2, - cos, - sin, - sqrt, - sqrt_rn, - rsqrt, - div_rn, - erf, - tanh, - floor, - ceil, - _check_dtype, - fma, -) -from .language.semantic_ext import ( - arange, - floordiv, - atom_red_typechecking_impl, - atomic_cas, - atomic_max, - atomic_min, - _load_legacy, - maximum, - minimum, - mod, - invert, - logical_and, - logical_or, - not_, - and_, - or_, - xor_, - minus, - dot_scaled, -) -from . import language - -language.cast = cast -language.dot = dot -language.flip = flip -language.sigmoid = sigmoid -language.softmax = softmax -language.gather = gather -language.insert_slice = insert_slice -language.extract_slice = extract_slice -language.get_element = get_element -language.tensor.__add__ = __add__ -language.tensor.__radd__ = __radd__ -language.tensor.__sub__ = __sub__ -language.tensor.__rsub__ = __rsub__ -language.tensor.__mul__ = __mul__ -language.tensor.__rmul__ = __rmul__ -language.tensor.__lshift__ = __lshift__ -language.tensor.__rshift__ = __rshift__ -language.trans = trans -language.parallel = parallel -language.compile_hint = compile_hint -language.sort = sort -language.multibuffer = multibuffer -language.sync_block_all = sync_block_all -language.sync_block_set = sync_block_set -language.sync_block_wait = sync_block_wait -language.make_tensor_descriptor = make_tensor_descriptor -language.tensor_descriptor = tensor_descriptor -language.tensor_descriptor_type = tensor_descriptor_type -language.load_tensor_descriptor = load_tensor_descriptor -language.store_tensor_descriptor = store_tensor_descriptor - -language.semantic.arange = arange -language.semantic.floordiv = floordiv -language.semantic.atom_red_typechecking_impl = atom_red_typechecking_impl -language.semantic.atomic_cas = atomic_cas -language.semantic.atomic_max = atomic_max -language.semantic.atomic_min = atomic_min -language.semantic._load_legacy = _load_legacy -language.semantic.maximum = maximum -language.semantic.minimum = minimum -language.semantic.invert = invert -language.semantic.logical_and = logical_and -language.semantic.logical_or = logical_or -language.semantic.mod = mod -language.semantic.not_ = not_ -language.semantic.and_ = and_ -language.semantic.or_ = or_ -language.semantic.xor_ = xor_ -language.semantic.minus = minus -language.semantic.dot_scaled = dot_scaled - -language.umulhi = umulhi -language.exp = exp -language.exp2 = exp2 -language.log = log -language.log2 = log2 -language.cos = cos -language.sin = sin -language.sqrt = sqrt -language.sqrt_rn = sqrt_rn -language.rsqrt = rsqrt -language.div_rn = div_rn -language.erf = erf -language.tanh = tanh -language.floor = floor -language.ceil = ceil -language.core.dtype.to_ir = dtype_to_ir -language.fma = fma -language.math.umulhi = umulhi -language.math.exp = exp -language.math.exp2 = exp2 -language.math.log = log -language.math.log2 = log2 -language.math.cos = cos -language.math.sin = sin 
-language.math.sqrt = sqrt -language.math.sqrt_rn = sqrt_rn -language.math.rsqrt = rsqrt -language.math.div_rn = div_rn -language.math.erf = erf -language.math.tanh = tanh -language.math.floor = floor -language.math.ceil = ceil -language.math._check_dtype = _check_dtype -language.math.fma = fma -language.math.isnan = language.extra.ascend.libdevice.isnan -language.math.isinf = language.extra.ascend.libdevice.isinf -language.math.reciprocal = language.extra.ascend.libdevice.reciprocal -language.math.log1p = language.extra.ascend.libdevice.log1p -language.math.relu = language.extra.ascend.libdevice.relu -language.math.tan = language.extra.ascend.libdevice.tan -language.math.atan = language.extra.ascend.libdevice.atan -language.math.tanh = language.extra.ascend.libdevice.tanh -language.math.ilogb = language.extra.ascend.libdevice.ilogb -language.math.ldexp = language.extra.ascend.libdevice.ldexp -language.math.pow = language.extra.ascend.libdevice.pow -language.math.flip = language.extra.ascend.libdevice.flip -language.math.atan2 = language.extra.ascend.libdevice.atan2 -language.math.div_rz = language.extra.ascend.libdevice.div_rz -language.math.fmod = language.extra.ascend.libdevice.fmod -language.math.trunc = language.extra.ascend.libdevice.trunc -language.math.round = language.extra.ascend.libdevice.round -language.math.finitef = finitef -language.math.isfinited = isfinited -language.math.rint = rint -language.math.atan2 = atan2 -language.extra.ascend.libdevice.umulhi = language.math.umulhi -language.extra.ascend.libdevice.exp = language.math.exp -language.extra.ascend.libdevice.exp2 = language.math.exp2 -language.extra.ascend.libdevice.log = language.math.log -language.extra.ascend.libdevice.log2 = language.math.log2 -language.extra.ascend.libdevice.cos = language.math.cos -language.extra.ascend.libdevice.sin = language.math.sin -language.extra.ascend.libdevice.sqrt = language.math.sqrt -language.extra.ascend.libdevice.sqrt_rn = language.math.sqrt_rn -language.extra.ascend.libdevice.rsqrt = language.math.rsqrt -language.extra.ascend.libdevice.div_rn = language.math.div_rn -language.extra.ascend.libdevice.erf = language.math.erf -language.extra.ascend.libdevice.tanh = language.math.tanh -language.extra.ascend.libdevice.floor = language.math.floor -language.extra.ascend.libdevice.ceil = language.math.ceil -language.extra.ascend.libdevice.fdiv = language.math.fdiv -language.extra.ascend.libdevice.fma = language.math.fma -language.extra.ascend.libdevice.abs = language.math.abs + +# flagtree backend specialization +from triton.runtime.driver import flagtree_backend_specialization +flagtree_backend_specialization('patch_triton_language') diff --git a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py index 9bf0e6863..803fa7376 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py @@ -47,4 +47,8 @@ 'get_triton_max_tensor_numel', 'is_block_shape_check_power_of_two', 'get_language_utils_BITWIDTH_DICT' + # testing + 'is_do_bench_npu', + 'ext_do_bench_npu', + 'patch_triton_language' ] diff --git a/third_party/ascend/backend/flagtree_backend_specialization/testing.py b/third_party/ascend/backend/flagtree_backend_specialization/testing.py new file mode 100644 index 000000000..8a1bddd97 --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/testing.py @@ -0,0 +1,316 @@ +import torch 
+import os + +def is_do_bench_npu(): + enable_bench_npu = os.getenv("TRITON_BENCH_METHOD", 'default').lower() == 'npu' + if torch.npu.is_available() and enable_bench_npu: + return True + return False + + +def collect_files(base_dir): + import pandas as pd + for root, dirs, files in os.walk(base_dir): + for file in files: + if file != 'op_statistic.csv': + continue + target_file = os.path.join(root, file) + df = pd.read_csv(target_file) + triton_rows = df[df['OP Type'].str.startswith('triton', na=False)] + if not triton_rows.empty: + return triton_rows['Avg Time(us)'].values[0] + return float('inf') + return float('inf') + + +def collect_single(base_dir: str, key: str = None) -> float: + if not os.path.exists(base_dir): + return float('inf') + + import pandas as pd + for root, _, files in os.walk(base_dir): + for file in files: + if file != 'op_statistic.csv': + continue + target_file = os.path.join(root, file) + df = pd.read_csv(target_file) + if key is not None: + key_rows = df[df['OP Type'].str.startswith(key, na=False)] + if not key_rows.empty: + return key_rows['Avg Time(us)'].values[0] + return float('inf') + else: + # default: read the first row except header + return df.loc[0, 'Avg Time(us)'] + + return float('inf') + + +def do_bench_npu(fn, warmup=5, active=30, prof_dir=None, keep_res=False): + import torch_npu + import multiprocessing + from triton import runtime + from datetime import datetime, timezone + + # warmup kernel + fn() + torch.npu.synchronize() + + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + l2_cache=False, + data_simplification=False + ) + skip_first = 1 + wait = 0 + repeat = 1 + total = skip_first + (wait + warmup + active) * repeat + + if prof_dir is not None: + torch_path = prof_dir + else: + process = multiprocessing.current_process() + pid = process.pid + process_name = process.name + timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S") + base_path = os.path.join(runtime.cache.get_home_dir(), ".triton", "profile_results") + torch_path = os.path.join(base_path, f"prof_{timestamp}_{process_name}-{pid}") + with torch_npu.profiler.profile( + activities=[ + torch_npu.profiler.ProfilerActivity.NPU + ], + schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, skip_first=skip_first), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path), + record_shapes=False, + profile_memory=False, + with_stack=False, + with_flops=False, + with_modules=False, + experimental_config=experimental_config, + ) as prof: + for _ in range(total): + fn() + prof.step() + torch.npu.synchronize() + + time = collect_single(torch_path) + + if not keep_res: + import shutil + if os.path.exists(torch_path): + shutil.rmtree(torch_path) + + return time + + +def ext_do_bench_npu(fn, warmup, rep, quantiles, return_mode): + import torch + from triton.testing import _summarize_statistics + avg_time = do_bench_npu(fn, warmup=max(5, warmup), active=max(30, rep)) + return _summarize_statistics(torch.tensor([avg_time], dtype=torch.float), quantiles, return_mode) + + +def patch_triton_language(): + # Patch the triton language API here because triton's __init__.py + # import testing in the last stages. 
+ from triton.language.tensor_descriptor import ( + tensor_descriptor, + tensor_descriptor_type, + ) + + from triton.language.core_ext import ( + dot, + cast, + gather, + get_element, + insert_slice, + extract_slice, + trans, + __add__, + __radd__, + __sub__, + __rsub__, + __mul__, + __rmul__, + __lshift__, + __rshift__, + parallel, + compile_hint, + make_tensor_descriptor, + load_tensor_descriptor, + store_tensor_descriptor, + multibuffer, + sync_block_all, + sync_block_set, + sync_block_wait, + dtype_to_ir, + sort + ) + from triton.language.standard_ext import flip, sigmoid, softmax, isfinited, finitef, rint, atan2 + from triton.language.math_ext import ( + umulhi, + exp, + exp2, + log, + log2, + cos, + sin, + sqrt, + sqrt_rn, + rsqrt, + div_rn, + erf, + tanh, + floor, + ceil, + _check_dtype, + fma, + ) + from triton.language.semantic_ext import ( + arange, + floordiv, + atom_red_typechecking_impl, + atomic_cas, + atomic_max, + atomic_min, + _load_legacy, + maximum, + minimum, + mod, + invert, + logical_and, + logical_or, + not_, + and_, + or_, + xor_, + minus, + dot_scaled, + ) + from . import language + + language.cast = cast + language.dot = dot + language.flip = flip + language.sigmoid = sigmoid + language.softmax = softmax + language.gather = gather + language.insert_slice = insert_slice + language.extract_slice = extract_slice + language.get_element = get_element + language.tensor.__add__ = __add__ + language.tensor.__radd__ = __radd__ + language.tensor.__sub__ = __sub__ + language.tensor.__rsub__ = __rsub__ + language.tensor.__mul__ = __mul__ + language.tensor.__rmul__ = __rmul__ + language.tensor.__lshift__ = __lshift__ + language.tensor.__rshift__ = __rshift__ + language.trans = trans + language.parallel = parallel + language.compile_hint = compile_hint + language.sort = sort + language.multibuffer = multibuffer + language.sync_block_all = sync_block_all + language.sync_block_set = sync_block_set + language.sync_block_wait = sync_block_wait + language.make_tensor_descriptor = make_tensor_descriptor + language.tensor_descriptor = tensor_descriptor + language.tensor_descriptor_type = tensor_descriptor_type + language.load_tensor_descriptor = load_tensor_descriptor + language.store_tensor_descriptor = store_tensor_descriptor + + language.semantic.arange = arange + language.semantic.floordiv = floordiv + language.semantic.atom_red_typechecking_impl = atom_red_typechecking_impl + language.semantic.atomic_cas = atomic_cas + language.semantic.atomic_max = atomic_max + language.semantic.atomic_min = atomic_min + language.semantic._load_legacy = _load_legacy + language.semantic.maximum = maximum + language.semantic.minimum = minimum + language.semantic.invert = invert + language.semantic.logical_and = logical_and + language.semantic.logical_or = logical_or + language.semantic.mod = mod + language.semantic.not_ = not_ + language.semantic.and_ = and_ + language.semantic.or_ = or_ + language.semantic.xor_ = xor_ + language.semantic.minus = minus + language.semantic.dot_scaled = dot_scaled + + language.umulhi = umulhi + language.exp = exp + language.exp2 = exp2 + language.log = log + language.log2 = log2 + language.cos = cos + language.sin = sin + language.sqrt = sqrt + language.sqrt_rn = sqrt_rn + language.rsqrt = rsqrt + language.div_rn = div_rn + language.erf = erf + language.tanh = tanh + language.floor = floor + language.ceil = ceil + language.core.dtype.to_ir = dtype_to_ir + language.fma = fma + language.math.umulhi = umulhi + language.math.exp = exp + language.math.exp2 = exp2 + 
language.math.log = log + language.math.log2 = log2 + language.math.cos = cos + language.math.sin = sin + language.math.sqrt = sqrt + language.math.sqrt_rn = sqrt_rn + language.math.rsqrt = rsqrt + language.math.div_rn = div_rn + language.math.erf = erf + language.math.tanh = tanh + language.math.floor = floor + language.math.ceil = ceil + language.math._check_dtype = _check_dtype + language.math.fma = fma + language.math.isnan = language.extra.ascend.libdevice.isnan + language.math.isinf = language.extra.ascend.libdevice.isinf + language.math.reciprocal = language.extra.ascend.libdevice.reciprocal + language.math.log1p = language.extra.ascend.libdevice.log1p + language.math.relu = language.extra.ascend.libdevice.relu + language.math.tan = language.extra.ascend.libdevice.tan + language.math.atan = language.extra.ascend.libdevice.atan + language.math.tanh = language.extra.ascend.libdevice.tanh + language.math.ilogb = language.extra.ascend.libdevice.ilogb + language.math.ldexp = language.extra.ascend.libdevice.ldexp + language.math.pow = language.extra.ascend.libdevice.pow + language.math.flip = language.extra.ascend.libdevice.flip + language.math.atan2 = language.extra.ascend.libdevice.atan2 + language.math.div_rz = language.extra.ascend.libdevice.div_rz + language.math.fmod = language.extra.ascend.libdevice.fmod + language.math.trunc = language.extra.ascend.libdevice.trunc + language.math.round = language.extra.ascend.libdevice.round + language.math.finitef = finitef + language.math.isfinited = isfinited + language.math.rint = rint + language.math.atan2 = atan2 + language.extra.ascend.libdevice.umulhi = language.math.umulhi + language.extra.ascend.libdevice.exp = language.math.exp + language.extra.ascend.libdevice.exp2 = language.math.exp2 + language.extra.ascend.libdevice.log = language.math.log + language.extra.ascend.libdevice.log2 = language.math.log2 + language.extra.ascend.libdevice.cos = language.math.cos + language.extra.ascend.libdevice.sin = language.math.sin + language.extra.ascend.libdevice.sqrt = language.math.sqrt + language.extra.ascend.libdevice.sqrt_rn = language.math.sqrt_rn + language.extra.ascend.libdevice.rsqrt = language.math.rsqrt + language.extra.ascend.libdevice.div_rn = language.math.div_rn + language.extra.ascend.libdevice.erf = language.math.erf + language.extra.ascend.libdevice.tanh = language.math.tanh + language.extra.ascend.libdevice.floor = language.math.floor + language.extra.ascend.libdevice.ceil = language.math.ceil + language.extra.ascend.libdevice.fdiv = language.math.fdiv + language.extra.ascend.libdevice.fma = language.math.fma + language.extra.ascend.libdevice.abs = language.math.abs \ No newline at end of file From a578de5450a85591c44b0326a512bcba0b75a999 Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Fri, 7 Nov 2025 19:09:42 +0800 Subject: [PATCH 11/23] [Decoupling] Fixpath of decoupled testing.py --- .../flagtree_backend_specialization/{ => triton}/testing.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename third_party/ascend/backend/flagtree_backend_specialization/{ => triton}/testing.py (100%) diff --git a/third_party/ascend/backend/flagtree_backend_specialization/testing.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/testing.py similarity index 100% rename from third_party/ascend/backend/flagtree_backend_specialization/testing.py rename to third_party/ascend/backend/flagtree_backend_specialization/triton/testing.py From d87876a02de44a5cee9d675e4ffdb952013e4763 Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: 
Thu, 13 Nov 2025 16:42:38 +0800 Subject: [PATCH 12/23] [Decoupling] Fix bugs in ascend python code specialization --- python/triton/runtime/driver.py | 9 ++++++--- .../backend/flagtree_backend_specialization/__init__.py | 3 ++- .../triton/compiler/code_generator.py | 2 +- .../triton/runtime/jit.py | 2 ++ .../flagtree_backend_specialization/triton/testing.py | 2 +- 5 files changed, 12 insertions(+), 6 deletions(-) diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index c0aa8cb9b..513c585bb 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -67,7 +67,8 @@ def flagtree_backend_specialization(function_name: str, *args, **kwargs): if hasattr(flagtree_backend_specialization, function_name): func = getattr(flagtree_backend_specialization, function_name) return func(*args, **kwargs) - return None + raise RuntimeError(f"{function_name} not found in flagtree_backend_specialization") + raise RuntimeError(f"flagtree_backend_specialization not found in {driver.active}") # flagtree backend func specialization @@ -77,7 +78,8 @@ def flagtree_backend_func_specialization(function_name: str): if hasattr(flagtree_backend_specialization, function_name): func = getattr(flagtree_backend_specialization, function_name) return func - return None + raise RuntimeError(f"{function_name} not found in flagtree_backend_specialization") + raise RuntimeError(f"flagtree_backend_specialization not found in {driver.active}") # flagtree backend class specialization def flagtree_backend_class_specialization(class_name: str): @@ -86,4 +88,5 @@ def flagtree_backend_class_specialization(class_name: str): if hasattr(flagtree_backend_specialization, class_name): cls = getattr(flagtree_backend_specialization, class_name) return cls - return None + raise RuntimeError(f"{class_name} not found in flagtree_backend_specialization") + raise RuntimeError(f"flagtree_backend_specialization not found in {driver.active}") diff --git a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py index 803fa7376..9cce49ef4 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py @@ -4,6 +4,7 @@ from .triton.runtime.jit import * from .triton.runtime.autotuner import * from .triton.language._utils import * +from .triton.testing import * __all__ = [ # compiler.compiler @@ -46,7 +47,7 @@ 'get_language_utils_IterableType_ObjPath', 'get_triton_max_tensor_numel', 'is_block_shape_check_power_of_two', - 'get_language_utils_BITWIDTH_DICT' + 'get_language_utils_BITWIDTH_DICT', # testing 'is_do_bench_npu', 'ext_do_bench_npu', diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/code_generator.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/code_generator.py index a5c7b794a..b90a9f99e 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/code_generator.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/code_generator.py @@ -43,7 +43,7 @@ def set_bind_sub_block_when_parallel(IteratorClass, iterator, bind_sub_block): def check_override_bind_sub_block(code_generator, node, bind_sub_block): # flagtree: After normal processing, check if we need to override bind_sub_block - if hasattr(node, 'lineno') and hasattr(self, 'jit_fn'): + if hasattr(node, 'lineno') and hasattr(code_generator, 
'jit_fn'): line_num = node.lineno # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later function_def = code_generator.jit_fn.parse() diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/jit.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/jit.py index 7dc30b0dd..5dd119f27 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/jit.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/jit.py @@ -57,6 +57,8 @@ def maps_line_numbers_to_comment_hints(jit_fn): # print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}") + return line_flagtree_hints + def attach_line_number_to_comment_mapping(tree, line_flagtree_hints): # Attach the line number to comment mapping to the function definition node tree.body[0].line_flagtree_hints = line_flagtree_hints diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/testing.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/testing.py index 8a1bddd97..15c856d9a 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/triton/testing.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/testing.py @@ -188,7 +188,7 @@ def patch_triton_language(): minus, dot_scaled, ) - from . import language + from triton import language language.cast = cast language.dot = dot From 5db2069c4d8632010e6f2bf04fe56c27c91d8020 Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Mon, 17 Nov 2025 21:57:55 +0800 Subject: [PATCH 13/23] [Decoupling] Locally use MLIRCompilationError in backend specialization --- python/triton/compiler/errors.py | 5 ----- python/triton/runtime/driver.py | 10 ---------- .../triton/compiler/compiler.py | 2 +- .../triton/runtime/autotuner.py | 2 +- 4 files changed, 2 insertions(+), 17 deletions(-) diff --git a/python/triton/compiler/errors.py b/python/triton/compiler/errors.py index 6d1cc470a..39e6c4dfb 100644 --- a/python/triton/compiler/errors.py +++ b/python/triton/compiler/errors.py @@ -49,8 +49,3 @@ class CompileTimeAssertionFailure(CompilationError): class UnsupportedLanguageConstruct(CompilationError): pass - - -# flagtree backend specialization -from triton.runtime.driver import flagtree_backend_class_specialization -MLIRCompilationError = flagtree_backend_class_specialization("MLIRCompilationError") diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index 513c585bb..58228de35 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -80,13 +80,3 @@ def flagtree_backend_func_specialization(function_name: str): return func raise RuntimeError(f"{function_name} not found in flagtree_backend_specialization") raise RuntimeError(f"flagtree_backend_specialization not found in {driver.active}") - -# flagtree backend class specialization -def flagtree_backend_class_specialization(class_name: str): - if hasattr(driver.active, "flagtree_backend_specialization"): - flagtree_backend_specialization = driver.active.flagtree_backend_specialization - if hasattr(flagtree_backend_specialization, class_name): - cls = getattr(flagtree_backend_specialization, class_name) - return cls - raise RuntimeError(f"{class_name} not found in flagtree_backend_specialization") - raise RuntimeError(f"flagtree_backend_specialization not found in {driver.active}") diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py 
b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py index 9bf881463..ef1426020 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py @@ -18,7 +18,7 @@ def set_CompiledKernel_metadata_stream(compiled_kernel, stream): return compiled_kernel.metadata.stream def handle_compile_error(e, ext): - from triton.compiler.errors import MLIRCompilationError + from .errors import MLIRCompilationError if (ext == "ttadapter"): stage_name = "ConvertTritonIRToLinalgIR" elif (ext == "npubin"): diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/autotuner.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/autotuner.py index 545a9e879..61e0badb0 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/autotuner.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/autotuner.py @@ -9,7 +9,7 @@ def get_spec_default_Autotuner_configs(): return Config({}) def ext_Autotuner_do_bench_MLIRCompilationError(exception_types): - from triton.compiler.errors import MLIRCompilationError + from ..compiler.errors import MLIRCompilationError return (MLIRCompilationError) def _profile(autotuner, *args, config, **meta): From db5f92151058a8676fabd97e7781fbc6fadb57ed Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Tue, 18 Nov 2025 10:46:51 +0800 Subject: [PATCH 14/23] [Decoupling] Fix language._utils decoupling --- python/triton/language/_utils.py | 23 +++---------------- python/triton/language/tensor_descriptor.py | 5 +++- .../__init__.py | 4 +--- .../triton/language/_utils.py | 16 +++---------- 4 files changed, 11 insertions(+), 37 deletions(-) diff --git a/python/triton/language/_utils.py b/python/triton/language/_utils.py index df9eaeaaf..b89037db2 100644 --- a/python/triton/language/_utils.py +++ b/python/triton/language/_utils.py @@ -1,15 +1,6 @@ -from __future__ import annotations - from typing import List -# flagtree backend specialization -from triton.runtime.driver import flagtree_backend_specialization -from typing import TYPE_CHECKING -if TYPE_CHECKING: - IterableType, ObjPath = flagtree_backend_specialization('get_language_utils_IterableType_ObjPath') - - -TRITON_MAX_TENSOR_NUMEL = flagtree_backend_specialization('get_triton_max_tensor_numel') +TRITON_MAX_TENSOR_NUMEL = 1048576 def is_power_of_two(x): @@ -21,6 +12,8 @@ def validate_block_shape(shape: List[int]): for i, d in enumerate(shape): if not isinstance(d, int): raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]") + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization if flagtree_backend_specialization('is_block_shape_check_power_of_two') and not is_power_of_two(d): raise ValueError(f"Shape element {i} must be a power of 2") numel *= d @@ -28,13 +21,3 @@ def validate_block_shape(shape: List[int]): if numel > TRITON_MAX_TENSOR_NUMEL: raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") return numel - - -# flagtree backend specialization -from triton.runtime.driver import flagtree_backend_specialization -BITWIDTH_DICT = flagtree_backend_specialization('get_language_utils_BITWIDTH_DICT') - - -# flagtree backend specialization -from triton.runtime.driver import flagtree_backend_func_specialization -get_primitive_bitwidth = 
flagtree_backend_func_specialization("get_primitive_bitwidth") diff --git a/python/triton/language/tensor_descriptor.py b/python/triton/language/tensor_descriptor.py index 077c889e1..83ed64a85 100644 --- a/python/triton/language/tensor_descriptor.py +++ b/python/triton/language/tensor_descriptor.py @@ -21,7 +21,7 @@ _str_to_eviction_policy, ) -from ._utils import validate_block_shape, get_primitive_bitwidth +from ._utils import validate_block_shape def _unwrap_if_constexpr(o): @@ -227,6 +227,9 @@ def __init__(self, name): name = _unwrap_if_constexpr(name) self.name = name assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_func_specialization + get_primitive_bitwidth = flagtree_backend_func_specialization("get_primitive_bitwidth") self.primitive_bitwidth = get_primitive_bitwidth(name) self.itemsize = self.primitive_bitwidth // 8 if name in dtype.SINT_TYPES: diff --git a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py index 9cce49ef4..17ced25ef 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py @@ -44,10 +44,8 @@ 'ext_Config_to_str', 'new_AutoTilingTuner', # language._utils - 'get_language_utils_IterableType_ObjPath', - 'get_triton_max_tensor_numel', 'is_block_shape_check_power_of_two', - 'get_language_utils_BITWIDTH_DICT', + 'get_primitive_bitwidth', # testing 'is_do_bench_npu', 'ext_do_bench_npu', diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/language/_utils.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/_utils.py index cfbe4681a..4810b0f6c 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/triton/language/_utils.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/_utils.py @@ -1,16 +1,10 @@ -from typing import TYPE_CHECKING, Any, Union, Dict +from __future__ import annotations -def get_language_utils_IterableType_ObjPath(): +from typing import TYPE_CHECKING, Any, Union, Dict +if TYPE_CHECKING: from triton.language import core IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type] ObjPath = tuple[int, ...] 
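The tensor_descriptor.py hunk above also illustrates why two dispatch helpers exist: flagtree_backend_specialization resolves a hook and calls it in one step, while flagtree_backend_func_specialization returns the callable itself, which is how dtype's constructor now obtains get_primitive_bitwidth without importing any backend module at import time. A toy sketch of the two styles against a stubbed backend object (names and the bitwidth table are illustrative):

class _FakeAscendSpecialization:
    # Stands in for the flagtree_backend_specialization module exposed by the active driver.
    @staticmethod
    def get_primitive_bitwidth(dtype: str) -> int:
        return {"i8": 8, "fp16": 16, "bf16": 16, "fp32": 32}[dtype]

    @staticmethod
    def is_block_shape_check_power_of_two() -> bool:
        return False


_active = _FakeAscendSpecialization()


def call_specialization(name: str, *args, **kwargs):
    # flagtree_backend_specialization: look the hook up and invoke it immediately.
    return getattr(_active, name)(*args, **kwargs)


def func_specialization(name: str):
    # flagtree_backend_func_specialization: hand the callable back to the caller.
    return getattr(_active, name)


print(call_specialization("is_block_shape_check_power_of_two"))  # False
get_bitwidth = func_specialization("get_primitive_bitwidth")
print(get_bitwidth("bf16"))  # 16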
- return IterableType, ObjPath - - -TRITON_MAX_TENSOR_NUMEL = 1048576 - -def get_triton_max_tensor_numel(): - return TRITON_MAX_TENSOR_NUMEL def is_block_shape_check_power_of_two(): @@ -31,10 +25,6 @@ def is_block_shape_check_power_of_two(): } -def get_language_utils_BITWIDTH_DICT(): - return BITWIDTH_DICT - - def get_primitive_bitwidth(dtype: str) -> int: return BITWIDTH_DICT[dtype] From f2179fa7bc5551484582f79b256332531fc92b8a Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Tue, 18 Nov 2025 10:51:48 +0800 Subject: [PATCH 15/23] [Decoupling] Reset triton language API patch --- python/triton/testing.py | 205 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 202 insertions(+), 3 deletions(-) diff --git a/python/triton/testing.py b/python/triton/testing.py index 92d88baee..a74e2aae1 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -516,6 +516,205 @@ def get_max_simd_tflops(dtype, clock_rate, device=None): return tflops -# flagtree backend specialization -from triton.runtime.driver import flagtree_backend_specialization -flagtree_backend_specialization('patch_triton_language') +# Patch the triton language API here because triton's __init__.py +# import testing in the last stages. +from .language.tensor_descriptor import ( + tensor_descriptor, + tensor_descriptor_type, +) + +from .language.core_ext import ( + dot, + cast, + gather, + get_element, + insert_slice, + extract_slice, + trans, + __add__, + __radd__, + __sub__, + __rsub__, + __mul__, + __rmul__, + __lshift__, + __rshift__, + parallel, + compile_hint, + make_tensor_descriptor, + load_tensor_descriptor, + store_tensor_descriptor, + multibuffer, + sync_block_all, + sync_block_set, + sync_block_wait, + dtype_to_ir, + sort +) +from .language.standard_ext import flip, sigmoid, softmax, isfinited, finitef, rint, atan2 +from .language.math_ext import ( + umulhi, + exp, + exp2, + log, + log2, + cos, + sin, + sqrt, + sqrt_rn, + rsqrt, + div_rn, + erf, + tanh, + floor, + ceil, + _check_dtype, + fma, +) +from .language.semantic_ext import ( + arange, + floordiv, + atom_red_typechecking_impl, + atomic_cas, + atomic_max, + atomic_min, + _load_legacy, + maximum, + minimum, + mod, + invert, + logical_and, + logical_or, + not_, + and_, + or_, + xor_, + minus, + dot_scaled, +) +from . 
import language + +language.cast = cast +language.dot = dot +language.flip = flip +language.sigmoid = sigmoid +language.softmax = softmax +language.gather = gather +language.insert_slice = insert_slice +language.extract_slice = extract_slice +language.get_element = get_element +language.tensor.__add__ = __add__ +language.tensor.__radd__ = __radd__ +language.tensor.__sub__ = __sub__ +language.tensor.__rsub__ = __rsub__ +language.tensor.__mul__ = __mul__ +language.tensor.__rmul__ = __rmul__ +language.tensor.__lshift__ = __lshift__ +language.tensor.__rshift__ = __rshift__ +language.trans = trans +language.parallel = parallel +language.compile_hint = compile_hint +language.sort = sort +language.multibuffer = multibuffer +language.sync_block_all = sync_block_all +language.sync_block_set = sync_block_set +language.sync_block_wait = sync_block_wait +language.make_tensor_descriptor = make_tensor_descriptor +language.tensor_descriptor = tensor_descriptor +language.tensor_descriptor_type = tensor_descriptor_type +language.load_tensor_descriptor = load_tensor_descriptor +language.store_tensor_descriptor = store_tensor_descriptor + +language.semantic.arange = arange +language.semantic.floordiv = floordiv +language.semantic.atom_red_typechecking_impl = atom_red_typechecking_impl +language.semantic.atomic_cas = atomic_cas +language.semantic.atomic_max = atomic_max +language.semantic.atomic_min = atomic_min +language.semantic._load_legacy = _load_legacy +language.semantic.maximum = maximum +language.semantic.minimum = minimum +language.semantic.invert = invert +language.semantic.logical_and = logical_and +language.semantic.logical_or = logical_or +language.semantic.mod = mod +language.semantic.not_ = not_ +language.semantic.and_ = and_ +language.semantic.or_ = or_ +language.semantic.xor_ = xor_ +language.semantic.minus = minus +language.semantic.dot_scaled = dot_scaled + +language.umulhi = umulhi +language.exp = exp +language.exp2 = exp2 +language.log = log +language.log2 = log2 +language.cos = cos +language.sin = sin +language.sqrt = sqrt +language.sqrt_rn = sqrt_rn +language.rsqrt = rsqrt +language.div_rn = div_rn +language.erf = erf +language.tanh = tanh +language.floor = floor +language.ceil = ceil +language.core.dtype.to_ir = dtype_to_ir +language.fma = fma +language.math.umulhi = umulhi +language.math.exp = exp +language.math.exp2 = exp2 +language.math.log = log +language.math.log2 = log2 +language.math.cos = cos +language.math.sin = sin +language.math.sqrt = sqrt +language.math.sqrt_rn = sqrt_rn +language.math.rsqrt = rsqrt +language.math.div_rn = div_rn +language.math.erf = erf +language.math.tanh = tanh +language.math.floor = floor +language.math.ceil = ceil +language.math._check_dtype = _check_dtype +language.math.fma = fma +language.math.isnan = language.extra.ascend.libdevice.isnan +language.math.isinf = language.extra.ascend.libdevice.isinf +language.math.reciprocal = language.extra.ascend.libdevice.reciprocal +language.math.log1p = language.extra.ascend.libdevice.log1p +language.math.relu = language.extra.ascend.libdevice.relu +language.math.tan = language.extra.ascend.libdevice.tan +language.math.atan = language.extra.ascend.libdevice.atan +language.math.tanh = language.extra.ascend.libdevice.tanh +language.math.ilogb = language.extra.ascend.libdevice.ilogb +language.math.ldexp = language.extra.ascend.libdevice.ldexp +language.math.pow = language.extra.ascend.libdevice.pow +language.math.flip = language.extra.ascend.libdevice.flip +language.math.atan2 = 
language.extra.ascend.libdevice.atan2 +language.math.div_rz = language.extra.ascend.libdevice.div_rz +language.math.fmod = language.extra.ascend.libdevice.fmod +language.math.trunc = language.extra.ascend.libdevice.trunc +language.math.round = language.extra.ascend.libdevice.round +language.math.finitef = finitef +language.math.isfinited = isfinited +language.math.rint = rint +language.math.atan2 = atan2 +language.extra.ascend.libdevice.umulhi = language.math.umulhi +language.extra.ascend.libdevice.exp = language.math.exp +language.extra.ascend.libdevice.exp2 = language.math.exp2 +language.extra.ascend.libdevice.log = language.math.log +language.extra.ascend.libdevice.log2 = language.math.log2 +language.extra.ascend.libdevice.cos = language.math.cos +language.extra.ascend.libdevice.sin = language.math.sin +language.extra.ascend.libdevice.sqrt = language.math.sqrt +language.extra.ascend.libdevice.sqrt_rn = language.math.sqrt_rn +language.extra.ascend.libdevice.rsqrt = language.math.rsqrt +language.extra.ascend.libdevice.div_rn = language.math.div_rn +language.extra.ascend.libdevice.erf = language.math.erf +language.extra.ascend.libdevice.tanh = language.math.tanh +language.extra.ascend.libdevice.floor = language.math.floor +language.extra.ascend.libdevice.ceil = language.math.ceil +language.extra.ascend.libdevice.fdiv = language.math.fdiv +language.extra.ascend.libdevice.fma = language.math.fma +language.extra.ascend.libdevice.abs = language.math.abs From 020ba50039e9c5445216784d358b9278fd966339 Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Wed, 19 Nov 2025 16:45:05 +0800 Subject: [PATCH 16/23] [Decoupling] Decouple semantic extension for ascend --- python/triton/language/semantic_ext.py | 25 ++- python/triton/testing.py | 41 ---- .../__init__.py | 43 ++++ .../triton/language/semantic.py | 184 ++++++++++++++++++ 4 files changed, 244 insertions(+), 49 deletions(-) create mode 100644 third_party/ascend/backend/flagtree_backend_specialization/triton/language/semantic.py diff --git a/python/triton/language/semantic_ext.py b/python/triton/language/semantic_ext.py index 8ef3e2ec9..e27dbb21e 100644 --- a/python/triton/language/semantic_ext.py +++ b/python/triton/language/semantic_ext.py @@ -44,7 +44,7 @@ def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) if src_ty == dst_ty: return input - + src_sca_ty = src_ty.scalar dst_sca_ty = dst_ty.scalar if src_sca_ty == dst_sca_ty: @@ -298,12 +298,12 @@ def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) - +# FIXME: non-exist in semantic.py def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: assert index.dtype.is_int(), "index must be an integer tensor" if not src.dtype.is_floating(): raise ValueError(f"Expected dtype fp16/fp32/bf16, but got {src.dtype}") - + rank = len(src.type.shape) assert len(index.type.shape) == rank, "source and index tensors must have the same rank" @@ -319,6 +319,7 @@ def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> gather = builder.create_gather(src.handle, index.handle, axis) return wrap_tensor(gather, src.type.scalar, index.type.shape) +# FIXME: non-exist in semantic.py def insert_slice(ful: tl.tensor, sub: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> 
tl.tensor: assert(len(ful.shape) == len(offsets)) assert(len(ful.shape) == len(sizes)) @@ -484,6 +485,7 @@ def maximum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: else: raise TypeError(f"Unexpected dtype {dtype}") +# FIXME: non-exist in semantic.py def extract_slice(ful: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tl.tensor: assert(len(ful.shape) == len(offsets)) assert(len(ful.shape) == len(sizes)) @@ -495,6 +497,7 @@ def extract_slice(ful: tl.tensor, offsets: List[tl.tensor], sizes: List[int], st out = builder.create_extract_slice(ful.handle, new_offsets, sizes, strides) return tl.tensor(out, ret_type) +# FIXME: non-exist in semantic.py def get_element(src: tl.tensor, indice: List[tl.tensor], builder: ir.builder): if len(src.shape) != len(indice): raise ValueError("Indice's rank must be equal to src tensor's rank") @@ -583,6 +586,7 @@ def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) +# FIXME: non-exist in semantic.py def compile_hint(ptr: tl.tensor, hint_name: str, hint_val, builder: ir.builder): if not hint_val: hint_val = builder.get_unit_attr() @@ -600,19 +604,21 @@ def compile_hint(ptr: tl.tensor, hint_name: str, hint_val, builder: ir.builder): builder.create_annotation(ptr.handle, hint_name, hint_val) +# FIXME: non-exist in semantic.py def custom_op(builder: ir.builder, op_name: str, **kwargs): if op_name == "sync_block_all": return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["mode"], kwargs["event_id"]) elif op_name == "sync_block_set": return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["sender"], kwargs["event_id"]) - + elif op_name == "sync_block_wait": return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["sender"], kwargs["event_id"]) - + raise ValueError(f"Unsupported custom op: {op_name}") +# FIXME: non-exist in semantic.py def sort(ptr: tl.tensor, dim: int, descending, builder: ir.builder): """ Triton sort 操作 @@ -692,7 +698,7 @@ def _str_to_fp_type(float_format: Optional[str]): return ir.F8F6F4TY.FP16 raise ValueError(f"Invalid float format: {float_format}.") - +# FIXME: non-exist in semantic.py def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder): triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16, "fp16": tl.float16}.get(float_format) if triton_ty is None: @@ -720,7 +726,7 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.te rhs_format: str = rhs_format.value lhs_format_enum = _str_to_fp_type(lhs_format) rhs_format_enum = _str_to_fp_type(rhs_format) - allowed_formats = {"bf16", "fp16"} # unsupported fp8/4 dtype: "e2m1", "e4m3", "e5m2" + allowed_formats = {"bf16", "fp16"} # unsupported fp8/4 dtype: "e2m1", "e4m3", "e5m2" assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}" assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}" rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None) @@ -759,6 +765,7 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.te +# FIXME: non-exist in semantic.py def scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: if dtype is None: raise ValueError("dtype must be specified when value is not a tensor") @@ -770,6 +777,7 @@ def scalar_constant(value, dtype: tl.dtype, builder: 
ir.builder) -> tl.tensor: return tl.tensor(value, dtype) +# FIXME: non-exist in semantic.py def make_scalar(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: if isinstance(value, tl.tensor): assert value.numel.value == 1, "only accepts size-1 tensor" @@ -777,6 +785,7 @@ def make_scalar(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: return scalar_constant(value, dtype, builder) +# FIXME: non-exist in semantic.py def make_tensor_descriptor( base: tl.tensor, shape: List[tl.tensor], @@ -805,7 +814,7 @@ def make_tensor_descriptor( strides[-1] = _unwrap_if_constexpr(strides[-1]) if strides[-1] != 1: raise ValueError(f"Tensor descriptor last dim must be 1 but got {strides[-1]}") - + shape = [make_scalar(x, tl.int32, builder) for x in shape] strides = [make_scalar(x, tl.int64, builder) for x in strides] diff --git a/python/triton/testing.py b/python/triton/testing.py index a74e2aae1..71a4f5346 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -571,27 +571,6 @@ def get_max_simd_tflops(dtype, clock_rate, device=None): _check_dtype, fma, ) -from .language.semantic_ext import ( - arange, - floordiv, - atom_red_typechecking_impl, - atomic_cas, - atomic_max, - atomic_min, - _load_legacy, - maximum, - minimum, - mod, - invert, - logical_and, - logical_or, - not_, - and_, - or_, - xor_, - minus, - dot_scaled, -) from . import language language.cast = cast @@ -625,26 +604,6 @@ def get_max_simd_tflops(dtype, clock_rate, device=None): language.load_tensor_descriptor = load_tensor_descriptor language.store_tensor_descriptor = store_tensor_descriptor -language.semantic.arange = arange -language.semantic.floordiv = floordiv -language.semantic.atom_red_typechecking_impl = atom_red_typechecking_impl -language.semantic.atomic_cas = atomic_cas -language.semantic.atomic_max = atomic_max -language.semantic.atomic_min = atomic_min -language.semantic._load_legacy = _load_legacy -language.semantic.maximum = maximum -language.semantic.minimum = minimum -language.semantic.invert = invert -language.semantic.logical_and = logical_and -language.semantic.logical_or = logical_or -language.semantic.mod = mod -language.semantic.not_ = not_ -language.semantic.and_ = and_ -language.semantic.or_ = or_ -language.semantic.xor_ = xor_ -language.semantic.minus = minus -language.semantic.dot_scaled = dot_scaled - language.umulhi = umulhi language.exp = exp language.exp2 = exp2 diff --git a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py index 17ced25ef..3640c58dc 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py @@ -4,6 +4,7 @@ from .triton.runtime.jit import * from .triton.runtime.autotuner import * from .triton.language._utils import * +from .triton.language.semantic import * from .triton.testing import * __all__ = [ @@ -46,6 +47,48 @@ # language._utils 'is_block_shape_check_power_of_two', 'get_primitive_bitwidth', + # language.semantic + "is_arange_check_power_of_two", + "check_arange_less_than_max_numel", + "is_cast_src_dst_scalar_type_equal", + "check_unsupported_fp8_fp64", + "ext_dot_lhs_supported_type", + "ext_dot_rhs_supported_type", + "dot_check_hf32_input_precision", + "is_dot_check_max_num_imprecise_acc", + "reset_dot_max_num_imprecise_acc", + "check_was_bool_to_int8_dtype", + "check_was_bool_to_int8_dtype_and_cast", + "check_unexpected_dtype_float", + 
"check_unexpected_dtype_bool", + "set_load_legacy_other_input", + "cast_back_when_load_legacy_ptr_is_bool", + "set_attr_was_bool_to_int8", + "is_atomic_need_original_check", + "ext_atomic_element_typechecking", + "is_atomic_cas_need_element_bitwidth_check", + "ext_atomic_cas_element_typechecking", + "is_atomic_max_no_bitcast", + "is_atomic_min_no_bitcast", + "atomic_max_returning_tensor", + "atomic_min_returning_tensor", + "is_float_format_support_bf16", + "is_float_format_support_fp16", + "ext_dot_scaled_validate_lhs_dtype", + "ext_dot_scaled_validate_rhs_dtype", + "ext_dot_scaled_check_same_dtype", + "is_dot_scaled_need_original_check", + "ext_dot_scaled_check_lhs_rhs_format", + "dot_scaled_recheck_rhs_scale_is_none", + "dot_scaled_check_lhs_scale_is_none", + "is_dot_scaled_support_rhs_scale", + "check_dot_scaled_lhs_scale_dtype", + "check_dot_scaled_rhs_scale_dtype", + "dot_scaled_lhs_bitcast_to_fp_type", + "dot_scaled_rhs_bitcast_to_fp_type", + "check_dot_scaled_dimension", + "check_dot_scaled_pack_size", + "set_dot_scaled_lhs_scale_handle" # testing 'is_do_bench_npu', 'ext_do_bench_npu', diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/language/semantic.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/semantic.py new file mode 100644 index 000000000..10212b272 --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/semantic.py @@ -0,0 +1,184 @@ +import triton.language as tl +from triton.language import cast +from triton.language.semantic import to_tensor, bitcast +from triton.language._utils import TRITON_MAX_TENSOR_NUMEL + +def is_arange_check_power_of_two(): + return False + +def check_arange_less_than_max_numel(range): + if range > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"end - start must be less than or equal to TRITON_MAX_TENSOR_NUMEL = {TRITON_MAX_TENSOR_NUMEL}") + +def is_cast_src_dst_scalar_type_equal(src_sca_ty, dst_sca_ty): + if src_sca_ty == dst_sca_ty: + return True + return False + +def check_unsupported_fp8_fp64(src_sca_ty, dst_sca_ty): + if (src_sca_ty.is_fp8() or dst_sca_ty.is_fp8()) or (src_sca_ty.is_fp64() or dst_sca_ty.is_fp64()): + raise ValueError("[fp8, fp64] is unsupported on Ascend for now." + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + +def ext_dot_lhs_supported_type(): + return (tl.int1) + +def ext_dot_rhs_supported_type(): + return (tl.int1) + +def dot_check_hf32_input_precision(input_precision, ir, lhs, rhs, ret_scalar_ty): + if (input_precision == getattr(ir.INPUT_PRECISION, "HF32")): + if (not lhs.dtype.is_fp32() or not rhs.dtype.is_fp32() or not ret_scalar_ty.is_fp32()): + raise ValueError("input_precision = 'hf32' must be used with f32 * f32 = f32 on Ascend") + +def is_dot_check_max_num_imprecise_acc(): + return False + +def reset_dot_max_num_imprecise_acc(max_num_imprecise_acc): + max_num_imprecise_acc = 0 + return max_num_imprecise_acc + +def check_was_bool_to_int8_dtype(input): + if hasattr(input, 'was_bool_to_int8'): + if input.type.scalar.is_int8(): + raise TypeError(f"unexpected type bool") + +def check_was_bool_to_int8_dtype_and_cast(input, builder): + assert input.type.scalar.is_int8(), "input wat bool to int8. However, input.type is not int8." 
+ return cast(input, tl.int1, builder) + +def check_unexpected_dtype_float(input): + if input.type.scalar.is_floating(): + raise TypeError(f"unexpected type {input.type.scalar}") + +def check_unexpected_dtype_bool(dtype): + if dtype.is_bool(): + raise TypeError(f"Unexpected dtype {dtype}") + +def set_load_legacy_other_input(builder): + return to_tensor(0, builder) + +def cast_back_when_load_legacy_ptr_is_bool(): + return False + +def set_attr_was_bool_to_int8(ret, is_bool): + if is_bool: + ret.was_bool_to_int8 = True + +def is_atomic_need_original_check(): + return False + +def ext_atomic_element_typechecking(element_ty, op): + # Add `tl.int64` restriction for NPU + if element_ty in [tl.int1, tl.int64, tl.float16, tl.float32, tl.float64, tl.bfloat16] and op in ['or', 'xor']: + raise ValueError(f"atomic_{op} does not support {str(element_ty)}. " + "All support dtypes are int8, int16, int32.") + if element_ty in [tl.int1, tl.int64, tl.float64, tl.bfloat16] and op == 'xchg': + raise ValueError(f"atomic_{op} does not support {str(element_ty)}. " + "All support dtypes are int8, int16, int32, float16, float32.") + if element_ty in [tl.int1, tl.int64, tl.float64]: + raise ValueError(f"atomic_{op} does not support {str(element_ty)}. " + "All support dtypes are int8, int16, int32, float16, float32, bfloat16.") + +def is_atomic_cas_need_element_bitwidth_check(): + return False + +def ext_atomic_cas_element_typechecking(element_ty): + if element_ty in [tl.int1, tl.int8, tl.float64, tl.bfloat16]: + raise ValueError(f"atomic_cas does not support {str(element_ty)}. " + "All support dtypes are int16, int32, int64, float16, float32.") + +def is_atomic_max_no_bitcast(): + return True + +def is_atomic_min_no_bitcast(): + return True + +def atomic_max_returning_tensor(ir, ptr, val, mask, sem, scope, builder): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + +def atomic_min_returning_tensor(ir, ptr, val, mask, sem, scope, builder): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + +def is_float_format_support_bf16(): + return True + +def is_float_format_support_fp16(): + return True + +def ext_dot_scaled_validate_lhs_dtype(lhs): + assert lhs.dtype == tl.bfloat16 or lhs.dtype == tl.float16, f"lhs matrix dtype must be bf16 or fp16" + +def ext_dot_scaled_validate_rhs_dtype(rhs): + assert rhs.dtype == tl.bfloat16 or rhs.dtype == tl.float16, f"rhs matrix dtype must be bf16 or fp16" + +def ext_dot_scaled_check_same_dtype(lhs, rhs): + assert lhs.dtype == rhs.dtype, f"lhs rhs matrix must get same dtype" + +def is_dot_scaled_need_original_check(): + return False + +def ext_dot_scaled_check_lhs_rhs_format(lhs_format, rhs_format): + lhs_format: str = lhs_format.value + rhs_format: str = rhs_format.value + allowed_formats = {"bf16", "fp16"} # unsupported fp8/4 dtype: "e2m1", "e4m3", "e5m2" + assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}" + assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}" + +def dot_scaled_recheck_rhs_scale_is_none(rhs_scale, rhs_scale_is_none): + rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None) + return rhs_scale_is_none + +def dot_scaled_check_lhs_scale_is_none(lhs_scale): + lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None) + return lhs_scale_is_none + +def is_dot_scaled_support_rhs_scale(): + 
return True + +def check_dot_scaled_lhs_scale_dtype(lhs_scale): + assert isinstance(lhs_scale, tl.tensor) and lhs_scale.dtype == tl.int8, f"lhs_scale must be int8 tensor" + +def check_dot_scaled_rhs_scale_dtype(rhs_scale, rhs_scale_is_none): + if not rhs_scale_is_none: + assert isinstance(rhs_scale, tl.tensor) and rhs_scale.dtype == tl.int8, f"rhs_scale must be int8 tensor" + +def _bitcast_to_fp_type(val, float_format, builder): + triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16, "fp16": tl.float16}.get(float_format) + if triton_ty is None: + assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}" + assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}" + return val + if val.dtype == triton_ty: + return val + else: + unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16, "fp16": tl.uint16}[float_format] + assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}" + return bitcast(val, triton_ty, builder) + +def dot_scaled_lhs_bitcast_to_fp_type(lhs, lhs_format, builder): + lhs_format: str = lhs_format.value + lhs = _bitcast_to_fp_type(lhs, lhs_format, builder) + return lhs + +def dot_scaled_rhs_bitcast_to_fp_type(rhs, rhs_format, builder): + rhs_format: str = rhs_format.value + rhs = _bitcast_to_fp_type(rhs, rhs_format, builder) + return rhs + +def check_dot_scaled_dimension(lhs, rhs): + assert lhs.type.shape[-1] == rhs.type.shape[-2], ( + f"lhs last dimension (columns) {lhs.shape[-1]} " + f"must equal rhs penultimate dimension (rows) {rhs.shape[-2]}" + ) + +def check_dot_scaled_pack_size(PACKED_A, K, lhs_format, lhs, rhs): + lhs_format: str = lhs_format.value + PACKED_B = 2 if lhs_format == "e2m1" else 1 + assert K * PACKED_B == PACKED_A * lhs.type.shape[ + -1], f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + +def set_dot_scaled_lhs_scale_handle(lhs_scale, lhs_scale_is_none): + return None if lhs_scale_is_none else lhs_scale.handle From 5513bbd90319c7f752f13999d3b4b0af2f97c686 Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Wed, 19 Nov 2025 16:47:45 +0800 Subject: [PATCH 17/23] [Decoupling] Decouple semantic extension for ascend (semantic.py) --- python/triton/language/semantic.py | 161 ++++++++++++++++++++++++++--- 1 file changed, 147 insertions(+), 14 deletions(-) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 8e9f87b5e..0ac0cfafc 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -305,6 +305,12 @@ def floordiv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Numbe input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_was_bool_to_int8_dtype', input) + flagtree_backend_specialization('check_was_bool_to_int8_dtype', other) + if input_scalar_ty.is_int() and other_scalar_ty.is_int(): ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty) input = cast(input, ret_ty, builder) @@ -332,6 +338,12 @@ def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, bu scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar # float % float + + # flagtree backend specialization + from triton.runtime.driver import 
flagtree_backend_specialization + flagtree_backend_specialization('check_was_bool_to_int8_dtype', input) + flagtree_backend_specialization('check_was_bool_to_int8_dtype', other) + if scalar_ty.is_floating(): # input - input.div(other, rounding_mode="floor") * other floor = math.floor(fdiv(input, other, False, builder), _builder=builder) @@ -358,6 +370,9 @@ def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, bu def minimum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): x, y = binary_op_type_checking_impl(x, y, builder) dtype = x.dtype + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_unexpected_dtype_bool', dtype) if dtype.is_floating(): if propagate_nan == tl.PropagateNan.ALL: return tl.tensor(builder.create_minimumf(x.handle, y.handle), x.type) @@ -376,6 +391,9 @@ def minimum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: def maximum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): x, y = binary_op_type_checking_impl(x, y, builder) dtype = x.dtype + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_unexpected_dtype_bool', dtype) if dtype.is_floating(): if propagate_nan == tl.PropagateNan.ALL: return tl.tensor(builder.create_maximumf(x.handle, y.handle), x.type) @@ -424,37 +442,66 @@ def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor, def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_unexpected_dtype_float', input) + input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_and(input.handle, other.handle), input.type) def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_unexpected_dtype_float', input) + input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_or(input.handle, other.handle), input.type) def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_unexpected_dtype_float', input) + input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if hasattr(input, 'was_bool_to_int8'): + input = flagtree_backend_specialization('check_was_bool_to_int8_dtype_and_cast', input, builder) if not input.type.is_int1(): input = bitcast(input, tl.dtype("int1"), builder) + if hasattr(other, 'was_bool_to_int8'): + other = flagtree_backend_specialization('check_was_bool_to_int8_dtype_and_cast', other, builder) if not other.type.is_int1(): other = bitcast(other, tl.dtype("int1"), builder) return and_(input, other, builder) def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + # flagtree 
backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if hasattr(input, 'was_bool_to_int8'): + input = flagtree_backend_specialization('check_was_bool_to_int8_dtype_and_cast', input, builder) if not input.type.is_int1(): input = bitcast(input, tl.dtype("int1"), builder) + if hasattr(other, 'was_bool_to_int8'): + other = flagtree_backend_specialization('check_was_bool_to_int8_dtype_and_cast', other, builder) if not other.type.is_int1(): other = bitcast(other, tl.dtype("int1"), builder) return or_(input, other, builder) def not_(input: tl.tensor, builder: ir.builder): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if hasattr(input, 'was_bool_to_int8'): + input = flagtree_backend_specialization('check_was_bool_to_int8_dtype_and_cast', input, builder) + if not input.type.is_int1(): input = bitcast(input, tl.dtype("int1"), builder) return invert(input, builder) @@ -486,6 +533,11 @@ def plus(input: tl.tensor) -> tl.tensor: def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor: input_sca_ty = input.type.scalar + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_was_bool_to_int8_dtype', input) + if input_sca_ty.is_ptr(): raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") _0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) @@ -493,7 +545,13 @@ def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor: def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if hasattr(input, 'was_bool_to_int8'): + input = flagtree_backend_specialization('check_was_bool_to_int8_dtype_and_cast', input, builder) + input_sca_ty = input.type.scalar + flagtree_backend_specialization('check_unexpected_dtype_float', input) if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty) @@ -609,8 +667,11 @@ def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: if end <= start: raise ValueError("arange's end argument must be greater than the start argument") range = end - start - if (range & (range - 1)) != 0: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if flagtree_backend_specialization('is_arange_check_power_of_two') and (range & (range - 1)) != 0: raise ValueError("arange's range must be a power of 2") + flagtree_backend_specialization('check_arange_less_than_max_numel', range) shape = [range] ret_ty = tl.block_type(tl.int32, shape) return tl.tensor(builder.create_make_range(start, end), ret_ty) @@ -842,6 +903,10 @@ def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, src_sca_ty = src_ty.scalar dst_sca_ty = dst_ty.scalar + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if flagtree_backend_specialization('is_cast_src_dst_scalar_type_equal', src_sca_ty, dst_sca_ty): + return input # For fp downcasting default rounding mode should be RTNE, for all other conversions it should # not be set @@ -856,6 +921,10 @@ def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, raise ValueError("fp_downcast_rounding should be set only for 
truncating fp conversions. " "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_unsupported_fp8_fp64', src_sca_ty, dst_sca_ty) + if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): assert builder.codegen_fns.get( "convert_custom_types") is not None, "target doesn't provide conversion for this type." @@ -1076,6 +1145,10 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_ "pointers or loading a scalar. Because the compiler does not know the boundary; please " "use block pointers (defined by `make_block_ptr`) instead") + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if other is None: + other = flagtree_backend_specialization('set_load_legacy_other_input', builder) # For a pointer of scalar, check the type of `mask` and `other` if not ptr.type.is_block(): if mask and mask.type.is_block(): @@ -1120,8 +1193,10 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_ ret = tl.tensor( builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, is_volatile), dst_ty) - if is_bool: + if is_bool and flagtree_backend_specialization('cast_back_when_load_legacy_ptr_is_bool'): ret = cast(ret, tl.int1, builder) + + flagtree_backend_specialization('set_attr_was_bool_to_int8', ret, is_bool) return ret @@ -1289,8 +1364,13 @@ def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: sem = _str_to_sem(sem) scope = _str_to_scope(scope) element_ty = ptr.type.scalar.element_ty - if element_ty.primitive_bitwidth not in [16, 32, 64]: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if element_ty.primitive_bitwidth not in [16, 32, 64] and flagtree_backend_specialization('is_atomic_cas_need_element_bitwidth_check'): raise ValueError("atomic_cas only supports elements with width {16, 32, 64}") + + flagtree_backend_specialization('ext_atomic_cas_element_typechecking', element_ty) + return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type) @@ -1301,10 +1381,16 @@ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, if ptr.type.is_const() or ptr.type.element_ty.is_const(): raise ValueError("Cannot store to a constant pointer") element_ty = ptr.type.scalar.element_ty - if element_ty is tl.float16 and op != 'add': + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if element_ty is tl.float16 and op != 'add' and flagtree_backend_specialization('is_atomic_need_original_check'): raise ValueError("atomic_" + op + " does not support fp16") - if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]: + if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16] and flagtree_backend_specialization('is_atomic_need_original_check'): raise ValueError("atomic_" + op + " does not support " + str(element_ty)) + + # flagtree backend specialization + flagtree_backend_specialization("ext_atomic_element_typechecking", element_ty, op) + if ptr.type.is_block(): if mask is not None: mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) @@ -1334,6 +1420,12 @@ def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: else: return tl.tensor( 
builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
+
+    # flagtree backend specialization
+    from triton.runtime.driver import flagtree_backend_specialization
+    if flagtree_backend_specialization('is_atomic_max_no_bitcast'):
+        return flagtree_backend_specialization('atomic_max_returning_tensor', ir, ptr, val, mask, sem, scope, builder)
+
     # for float
     # return atomic_smax(i_ptr, i_val) if val >= 0
     # return atomic_umin(i_ptr, i_val) if val < 0
@@ -1373,6 +1465,12 @@ def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope:
         else:
             return tl.tensor(
                 builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
+
+    # flagtree backend specialization
+    from triton.runtime.driver import flagtree_backend_specialization
+    if flagtree_backend_specialization('is_atomic_min_no_bitcast'):
+        return flagtree_backend_specialization('atomic_min_returning_tensor', ir, ptr, val, mask, sem, scope, builder)
+
     # for float
     # return atomic_smin(i_ptr, i_val) if val >= 0
     # return atomic_umax(i_ptr, i_val) if val < 0
@@ -1463,10 +1561,12 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona
         # All combinations of supported fp8 x fp8 are permitted
         pass
     else:
+        # flagtree backend specialization
+        from triton.runtime.driver import flagtree_backend_specialization
         assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
-                             tl.float32), f"Unsupported lhs dtype {lhs.dtype}"
+                             tl.float32) + (flagtree_backend_specialization('ext_dot_lhs_supported_type') or ()), f"Unsupported lhs dtype {lhs.dtype}"
         assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
-                             tl.float32), f"Unsupported rhs dtype {rhs.dtype}"
+                             tl.float32) + (flagtree_backend_specialization('ext_dot_rhs_supported_type') or ()), f"Unsupported rhs dtype {rhs.dtype}"
         assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. 
Got {lhs.dtype} and {rhs.dtype}" if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): @@ -1513,6 +1613,10 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona acc_handle = acc.handle assert acc.type == ret_ty + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('dot_check_hf32_input_precision', input_precision, ir, lhs, rhs, ret_scalar_ty) + # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 if max_num_imprecise_acc is None: if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): @@ -1520,9 +1624,9 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona else: max_num_imprecise_acc = 0 else: - if lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc > K: + if flagtree_backend_specialization('is_dot_check_max_num_imprecise_acc') and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc > K: raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})") - + max_num_imprecise_acc = flagtree_backend_specialization('reset_dot_max_num_imprecise_acc', max_num_imprecise_acc) return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), ret_ty) @@ -1538,6 +1642,14 @@ def _str_to_fp_type(float_format: Optional[str]): return ir.F8F6F4TY.E3M2 if float_format == 'e2m1': return ir.F8F6F4TY.E2M1 + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if float_format == 'bf16' and flagtree_backend_specialization('is_float_format_support_bf16'): + return ir.F8F6F4TY.BF16 + if float_format == 'fp16' and flagtree_backend_specialization('is_float_format_support_fp16'): + return ir.F8F6F4TY.FP16 + raise ValueError(f"Invalid float format: {float_format}.") @@ -1545,22 +1657,42 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor, rhs_format, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: assert lhs.type.is_block() and rhs.type.is_block() #TODO: validate types. 
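# --- Illustrative sketch (not part of the patch) ----------------------------
# The hunks above and below all follow the same pattern: an upstream check is
# made conditional on an "is_*" hook, and a backend-specific "ext_*" hook then
# takes over. The sketch shows that control flow in isolation for the
# scaled-dot format checks; the hook lookup is mocked with a plain dict here,
# whereas the real code resolves hooks through
# triton.runtime.driver.flagtree_backend_specialization.

def _mock_is_dot_scaled_need_original_check():
    # Ascend-style answer: skip the upstream fp8-only restriction.
    return False

def _mock_ext_dot_scaled_check_lhs_rhs_format(lhs_format, rhs_format):
    # Ascend-style answer: only bf16/fp16 scaled dot is supported.
    allowed = {"bf16", "fp16"}
    assert lhs_format in allowed, f"NYI: lhs_format {lhs_format}"
    assert rhs_format in allowed, f"NYI: rhs_format {rhs_format}"

_mock_hooks = {
    "is_dot_scaled_need_original_check": _mock_is_dot_scaled_need_original_check,
    "ext_dot_scaled_check_lhs_rhs_format": _mock_ext_dot_scaled_check_lhs_rhs_format,
}

def _mock_specialization(name, *args):
    # Returns None when no hook with that name is provided.
    func = _mock_hooks.get(name)
    return func(*args) if func is not None else None

def _check_formats(lhs_format, rhs_format):
    # The upstream restriction only fires when the backend asks for the original check.
    assert lhs_format in ("e2m1", "e4m3", "e5m2") or not _mock_specialization(
        "is_dot_scaled_need_original_check"), f"NYI: lhs_format {lhs_format}"
    # The backend-specific restriction then narrows the accepted formats again.
    _mock_specialization("ext_dot_scaled_check_lhs_rhs_format", lhs_format, rhs_format)

_check_formats("bf16", "fp16")  # accepted under the mocked Ascend-style hooks
# -----------------------------------------------------------------------------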
+
+    # flagtree backend specialization
+    from triton.runtime.driver import flagtree_backend_specialization
+    flagtree_backend_specialization('ext_dot_scaled_validate_lhs_dtype', lhs)
+    flagtree_backend_specialization('ext_dot_scaled_validate_rhs_dtype', rhs)
+    flagtree_backend_specialization('ext_dot_scaled_check_same_dtype', lhs, rhs)
+
     lhs_rank = len(lhs.shape)
     rhs_rank = len(rhs.shape)
     assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
     lhs_format_enum = _str_to_fp_type(lhs_format)
     rhs_format_enum = _str_to_fp_type(rhs_format)
-    assert lhs_format in ("e2m1", "e4m3", "e5m2"), f"NYI: lhs_format {lhs_format}"
-    assert rhs_format in ("e4m3", "e5m2"), f"NYI: rhs_format {rhs_format}"
+    assert lhs_format in ("e2m1", "e4m3", "e5m2") or not flagtree_backend_specialization('is_dot_scaled_need_original_check'), f"NYI: lhs_format {lhs_format}"
+    assert rhs_format in ("e4m3", "e5m2") or not flagtree_backend_specialization('is_dot_scaled_need_original_check'), f"NYI: rhs_format {rhs_format}"
+    flagtree_backend_specialization('ext_dot_scaled_check_lhs_rhs_format', lhs_format, rhs_format)
+
     rhs_scale_is_none = isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None
-    assert rhs_scale_is_none, "NYI: rhs_scale not supported"
+    rhs_scale_is_none = flagtree_backend_specialization('dot_scaled_recheck_rhs_scale_is_none', rhs_scale, rhs_scale_is_none)
+    lhs_scale_is_none = flagtree_backend_specialization('dot_scaled_check_lhs_scale_is_none', lhs_scale)
+
+    assert rhs_scale_is_none or flagtree_backend_specialization('is_dot_scaled_support_rhs_scale'), "NYI: rhs_scale not supported"
+
+    flagtree_backend_specialization('check_dot_scaled_lhs_scale_dtype', lhs_scale)
+    flagtree_backend_specialization('check_dot_scaled_rhs_scale_dtype', rhs_scale, rhs_scale_is_none)
+    lhs = flagtree_backend_specialization('dot_scaled_lhs_bitcast_to_fp_type', lhs, lhs_format, builder)
+    rhs = flagtree_backend_specialization('dot_scaled_rhs_bitcast_to_fp_type', rhs, rhs_format, builder)
+    flagtree_backend_specialization('check_dot_scaled_dimension', lhs, rhs)
     M = lhs.type.shape[-2]
     K, N = rhs.type.shape[-2:]
     PACKED = 2 if lhs_format == "e2m1" else 1
     assert K == PACKED * lhs.type.shape[
-        -1], f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
-    assert K >= 64, f"scaled_dot NYI for K < 64. Got {K=}"
+        -1] or not flagtree_backend_specialization('is_dot_scaled_need_original_check'), \
+        f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
+    flagtree_backend_specialization('check_dot_scaled_pack_size', PACKED, K, lhs_format, lhs, rhs)
+    assert K >= 64 or not flagtree_backend_specialization('is_dot_scaled_need_original_check'), f"scaled_dot NYI for K < 64. 
Got {K=}" B = lhs.type.shape[0] if lhs_rank == 3 else None ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N]) @@ -1571,6 +1703,7 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor, acc_handle = acc.handle assert acc.type == ret_ty rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle + lhs_scale_handle = flagtree_backend_specialization('set_dot_scaled_lhs_scale_handle', lhs_scale, lhs_scale_is_none) return tl.tensor( builder.create_dot_scaled(lhs.handle, lhs_scale.handle, lhs_format_enum, rhs.handle, rhs_scale_handle, rhs_format_enum, acc_handle), ret_ty) From ba35a8c25985c0c20598a5eb67dbd1134a9b33d2 Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Thu, 20 Nov 2025 15:34:06 +0800 Subject: [PATCH 18/23] [Decoupling] Decouple semantic extension for ascend (finish) --- python/triton/language/core_ext.py | 6 +- python/triton/language/semantic.py | 66 ++++++ .../__init__.py | 12 +- .../triton/language/semantic.py | 212 +++++++++++++++++- 4 files changed, 289 insertions(+), 7 deletions(-) diff --git a/python/triton/language/core_ext.py b/python/triton/language/core_ext.py index f44af20c2..e466b0381 100644 --- a/python/triton/language/core_ext.py +++ b/python/triton/language/core_ext.py @@ -20,7 +20,7 @@ mul, ) from typing import Optional -from . import semantic_ext as semantic +from . import semantic from .tensor_descriptor import tensor_descriptor, tensor_descriptor_base @@ -319,7 +319,7 @@ def sort(ptr, dim=-1, descending=False, _builder=None): semantic.compile_hint(ret, "overflow_mode", constexpr("saturate"), _builder) return ret - + @builtin def multibuffer(src: tensor, size, _builder=None): """ @@ -475,4 +475,4 @@ def dtype_to_ir(self, builder: ir.builder) -> ir.type: elif self.name == 'fp64': return builder.get_double_ty() raise ValueError(f'fail to convert {self} to ir type') - + diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 0ac0cfafc..6f28d3efb 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1927,3 +1927,69 @@ def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: # Advanced block pointer type is the same as before return tl.tensor(builder.create_advance(base.handle, offsets), base.type) + + +def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_semantic_gather', src, index, axis, builder) + + +def insert_slice(ful: tl.tensor, sub: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_semantic_insert_slice', ful, sub, offsets, sizes, strides, builder) + + +def extract_slice(ful: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_semantic_extract_slice', ful, offsets, sizes, strides, builder) + + +def get_element(src: tl.tensor, indice: List[tl.tensor], builder: ir.builder): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return 
flagtree_backend_specialization('ext_semantic_get_element', src, indice, builder) + + +def compile_hint(ptr: tl.tensor, hint_name: str, hint_val, builder: ir.builder): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("ext_semantic_compile_hint", ptr, hint_name, hint_val, builder) + + +def custom_op(builder: ir.builder, op_name: str, **kwargs): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_semantic_custom_op', builder, op_name, **kwargs) + + +def sort(ptr: tl.tensor, dim: int, descending, builder: ir.builder): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_semantic_sort', ptr, dim, descending, builder) + + +def scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_semantic_scalar_constant', value, dtype, builder) + + +def make_scalar(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_semantic_make_scalar', value, dtype, builder) + + +def make_tensor_descriptor( + base: tl.tensor, + shape: List[tl.tensor], + strides: List[tl.tensor], + block_shape: List[tl.constexpr], + builder: ir.builder +) -> tensor_descriptor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_semantic_make_tensor_descriptor', base, shape, strides, block_shape, builder) diff --git a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py index 3640c58dc..b20045449 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py @@ -88,7 +88,17 @@ "dot_scaled_rhs_bitcast_to_fp_type", "check_dot_scaled_dimension", "check_dot_scaled_pack_size", - "set_dot_scaled_lhs_scale_handle" + "set_dot_scaled_lhs_scale_handle", + "ext_semantic_gather", + "ext_semantic_insert_slice", + "ext_semantic_extract_slice", + "ext_semantic_get_element", + "ext_semantic_compile_hint", + "ext_semantic_custom_op", + "ext_semantic_sort", + "ext_semantic_scalar_constant", + "ext_semantic_make_scalar", + "ext_semantic_make_tensor_descriptor", # testing 'is_do_bench_npu', 'ext_do_bench_npu', diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/language/semantic.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/semantic.py index 10212b272..69821759d 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/triton/language/semantic.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/semantic.py @@ -1,7 +1,15 @@ +from typing import List import triton.language as tl +from triton._C.libtriton import ir from triton.language import cast -from triton.language.semantic import to_tensor, bitcast +from triton.language.semantic import to_tensor, bitcast, wrap_tensor from triton.language._utils import TRITON_MAX_TENSOR_NUMEL +from triton.language.tensor_descriptor import ( + 
_unwrap_if_constexpr, + _unwrap_shape, + block_type, + tensor_descriptor +) def is_arange_check_power_of_two(): return False @@ -21,10 +29,10 @@ def check_unsupported_fp8_fp64(src_sca_ty, dst_sca_ty): "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) def ext_dot_lhs_supported_type(): - return (tl.int1) + return (tl.int1,) def ext_dot_rhs_supported_type(): - return (tl.int1) + return (tl.int1,) def dot_check_hf32_input_precision(input_precision, ir, lhs, rhs, ret_scalar_ty): if (input_precision == getattr(ir.INPUT_PRECISION, "HF32")): @@ -182,3 +190,201 @@ def check_dot_scaled_pack_size(PACKED_A, K, lhs_format, lhs, rhs): def set_dot_scaled_lhs_scale_handle(lhs_scale, lhs_scale_is_none): return None if lhs_scale_is_none else lhs_scale.handle + +def ext_semantic_gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + assert index.dtype.is_int(), "index must be an integer tensor" + if not src.dtype.is_floating(): + raise ValueError(f"Expected dtype fp16/fp32/bf16, but got {src.dtype}") + + rank = len(src.type.shape) + assert len(index.type.shape) == rank, "source and index tensors must have the same rank" + + assert -rank <= axis < rank, f"gather axis {axis} must be < source rank ({rank})" + if axis < 0: + axis += rank + + for d in range(rank): + if d == axis: + continue + assert index.type.shape[d] == src.type.shape[d], f"index dim {axis} must match the corresponding source dim" + + gather = builder.create_gather(src.handle, index.handle, axis) + return wrap_tensor(gather, src.type.scalar, index.type.shape) + +def ext_semantic_insert_slice(ful: tl.tensor, sub: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tl.tensor: + assert(len(ful.shape) == len(offsets)) + assert(len(ful.shape) == len(sizes)) + assert(len(ful.shape) == len(strides)) + assert(all([s>=1 for s in sizes])) + assert(all([s>=0 for s in strides])) + new_offsets = [o.handle for o in offsets] + ret_type = tl.block_type(ful.type.scalar, ful.shape) + out = builder.create_insert_slice(ful.handle, sub.handle, new_offsets, sizes, strides) + return tl.tensor(out, ret_type) + +def ext_semantic_extract_slice(ful: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tl.tensor: + assert(len(ful.shape) == len(offsets)) + assert(len(ful.shape) == len(sizes)) + assert(len(ful.shape) == len(strides)) + assert(all([s>=1 for s in sizes])) + assert(all([s>=0 for s in strides])) + new_offsets = [o.handle for o in offsets] + ret_type = tl.block_type(ful.type.scalar, sizes) + out = builder.create_extract_slice(ful.handle, new_offsets, sizes, strides) + return tl.tensor(out, ret_type) + +def ext_semantic_get_element(src: tl.tensor, indice: List[tl.tensor], builder: ir.builder): + if len(src.shape) != len(indice): + raise ValueError("Indice's rank must be equal to src tensor's rank") + + new_indice = [i.handle for i in indice] + result = builder.create_extract_scalar(src.handle, new_indice) + return wrap_tensor(result, src.type.scalar, None) + +def ext_semantic_compile_hint(ptr: tl.tensor, hint_name: str, hint_val, builder: ir.builder): + if not hint_val: + hint_val = builder.get_unit_attr() + elif isinstance(hint_val, bool): + hint_val = builder.get_bool_attr(hint_val) + elif isinstance(hint_val, int): + hint_val = builder.get_int32_attr(hint_val) + elif isinstance(hint_val, tl.constexpr): + hint_val = builder.get_str_attr(hint_val.value) + elif isinstance(hint_val, list): + # only 
support i64 array attr for now
+        hint_val = builder.get_i64_array_attr(hint_val)
+    else:
+        raise ValueError(f"Unsupported hint value type: {type(hint_val)}")
+    builder.create_annotation(ptr.handle, hint_name, hint_val)
+
+def ext_semantic_custom_op(builder: ir.builder, op_name: str, **kwargs):
+    if op_name == "sync_block_all":
+        return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["mode"], kwargs["event_id"])
+
+    elif op_name == "sync_block_set":
+        return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["sender"], kwargs["event_id"])
+
+    elif op_name == "sync_block_wait":
+        return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["sender"], kwargs["event_id"])
+
+    raise ValueError(f"Unsupported custom op: {op_name}")
+
+def ext_semantic_sort(ptr: tl.tensor, dim: int, descending, builder: ir.builder):
+    """
+    Triton sort operation.
+
+    Args:
+        ptr: tl.tensor, the input tensor
+        dim: int, the dimension to sort along; must be the last (innermost) dimension
+        descending: bool or constexpr, whether to sort in descending order
+        builder: ir.builder, the underlying IR builder
+    Returns:
+        values: tl.tensor, the sorted values (same type as the input)
+    """
+
+    allowed_types = {tl.int8, tl.int16, tl.bfloat16, tl.float16, tl.float32}
+    base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type
+    if base_ty not in allowed_types:
+        raise TypeError(
+            f"tt.sort only supports int8, int16, bfloat16, float16, float32, "
+            f"but got {ptr.type}"
+        )
+
+    shape = getattr(ptr, "shape", None)
+    if shape is None or shape == ():
+        shape = getattr(getattr(ptr, "type", None), "shape", None)
+
+    rank = None
+    if shape is not None:
+        try:
+            rank = len(shape)
+        except Exception:
+            rank = len(list(shape))
+
+    if rank is not None:
+        if rank < 1:
+            raise ValueError("tt.sort requires tensor rank >= 1")
+        last_dim = rank - 1
+        norm_dim = dim if dim >= 0 else dim + rank
+        if norm_dim != last_dim:
+            raise ValueError(
+                f"tt.sort only supports sorting along the last dimension "
+                f"(dim={last_dim} or -1) for shape {tuple(shape)}, but got dim={dim}"
+            )
+        dim = last_dim
+    else:
+        if dim != -1:
+            raise ValueError(
+                "tt.sort only supports the last dimension; when rank is unknown "
+                "you must pass dim=-1"
+            )
+
+    if hasattr(descending, "value"):
+        descending = bool(descending.value)
+    else:
+        descending = bool(descending)
+
+    sorted_vals = builder.create_sort(ptr.handle, dim, descending)
+
+    values = tl.tensor(sorted_vals, type=ptr.type)
+
+    return values
+
+def ext_semantic_scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
+    if dtype is None:
+        raise ValueError("dtype must be specified when value is not a tensor")
+    if value == 0:
+        value = builder.get_null_value(dtype.to_ir(builder))
+    else:
+        get_value_fn = getattr(builder, f"get_{dtype.name}")
+        value = get_value_fn(value)
+    return tl.tensor(value, dtype)
+
+def ext_semantic_make_scalar(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
+    if isinstance(value, tl.tensor):
+        assert value.numel.value == 1, "only accepts size-1 tensor"
+        return cast(value, dtype, builder)
+    return ext_semantic_scalar_constant(value, dtype, builder)
+
+def ext_semantic_make_tensor_descriptor(
+    base: tl.tensor,
+    shape: List[tl.tensor],
+    strides: List[tl.tensor],
+    block_shape: List[tl.constexpr],
+    builder: ir.builder
+) -> tensor_descriptor:
+    ndim = len(shape)
+    if not (1 <= ndim <= 5):
+        raise ValueError(f"Expected 1 <= ndim <= 5 but got {ndim} dimensions")
+    if len(strides) != ndim:
+        raise ValueError(f"Expected {ndim} strides but got {len(strides)}")
+    if len(block_shape) != ndim:
+        raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(block_shape)}")
+    assert isinstance(base.dtype, tl.pointer_type)
+    primitive_bitwidth = base.dtype.element_ty.primitive_bitwidth
+    if primitive_bitwidth == 1:
+        raise ValueError("int1 type is not supported for make_tensor_descriptor yet")
+    elem_size = primitive_bitwidth // 8
+    contig_dim_size = _unwrap_if_constexpr(block_shape[-1])
+    if contig_dim_size * elem_size < 16:
+        raise ValueError(
+            f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes"
+        )
+
+    strides[-1] = _unwrap_if_constexpr(strides[-1])
+    if strides[-1] != 1:
+        raise ValueError(f"Tensor descriptor last dim must be 1 but got {strides[-1]}")
+
+    shape = [ext_semantic_make_scalar(x, tl.int32, builder) for x in shape]
+    strides = [ext_semantic_make_scalar(x, tl.int64, builder) for x in strides]
+
+    block_shape = _unwrap_shape(block_shape)
+
+    assert isinstance(base.type, tl.pointer_type)
+    desc_block_type = block_type(base.type.element_ty, block_shape)
+    base_handle = base.handle
+    is_signed_int = base.type.element_ty.is_int_signed()
+
+    handle = builder.create_make_tensor_descriptor(base_handle, [s.handle for s in shape],
+                                                   [s.handle for s in strides], block_shape, is_signed_int)
+    return tensor_descriptor(handle, shape, strides, desc_block_type)

From a7af2125b15d0ff29b336c3dc5da70208719f110 Mon Sep 17 00:00:00 2001
From: liuyunqi20
Date: Thu, 20 Nov 2025 18:11:14 +0800
Subject: [PATCH 19/23] [Decoupling] Decouple language.core extension for ascend

---
 python/triton/language/core.py     | 205 ++++++++++-
 python/triton/language/core_ext.py |  23 ++
 python/triton/testing.py           |  55 ---
 .../__init__.py                    |  29 ++
 .../triton/language/core.py        | 334 ++++++++++++++++++
 5 files changed, 588 insertions(+), 58 deletions(-)
 create mode 100644 third_party/ascend/backend/flagtree_backend_specialization/triton/language/core.py

diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index e2c57b388..95dfee352 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -1292,6 +1292,11 @@ def trans(input: tensor, *dims, _builder=None):
     """
     if not dims:
         dims = (1, 0)
+
+    # flagtree backend specialization
+    from triton.runtime.driver import flagtree_backend_specialization
+    dims = flagtree_backend_specialization('ext_trans_unwrap_iterable', dims)
+
     return semantic.permute(input, dims, _builder)
 
 
@@ -1482,7 +1487,7 @@ def expand_dims(input, axis, _builder=None):
 
 @_tensor_member_fn
 @builtin
-def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None):
+def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, overflow_mode: Optional[str] = None, _builder=None):
     """
     Casts a tensor to the given :code:`dtype`.
 
@@ -1497,13 +1502,25 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas
     :param bitcast: If true, the tensor is bitcasted to the given :code:`dtype`,
         instead of being numerically casted.
     :type bitcast: bool, optional
+    :param overflow_mode: When overflow_mode is not set or is "trunc",
+        overflow is handled by truncation (cut-off). When overflow_mode is
+        "saturate", overflow is clamped to the maximum value of the data
+        type.
+ :type overflow_mode: string, optional """ + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + overflow_modes = flagtree_backend_specialization('ext_cast_set_overflow_modes') + input = semantic.to_tensor(input, _builder) if isinstance(bitcast, constexpr): bitcast = bitcast.value if bitcast: return semantic.bitcast(input, dtype, _builder) - return semantic.cast(input, dtype, _builder, fp_downcast_rounding) + # flagtree backend specialization + ret = semantic.cast(input, dtype, _builder, fp_downcast_rounding) + flagtree_backend_specialization('ext_cast_check_overflow_mode', overflow_mode, overflow_modes, ret, _builder) + return ret # ----------------------- @@ -1537,11 +1554,18 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i specified (i.e. at least one must be :code:`None`). """ assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified" + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('check_dot_deprecated_param_allow_tf32', allow_tf32) + if input_precision is None: supports_tf32 = _builder and "tf32" in _builder.options.allowed_dot_input_precisions default_precision = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee" input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision) - + else: + # flagtree backend specialization + flagtree_backend_specialization('check_dot_invalid_input_precision', input_precision) input_precision = _constexpr_to_value(input_precision) out_dtype = _constexpr_to_value(out_dtype) max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc) @@ -1659,6 +1683,35 @@ def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None): return semantic.descriptor_store(desc_pointer, value, offsets, _builder) +@builtin +def load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union[constexpr, tensor]], + _builder=None) -> tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_load_tensor_descriptor', desc, offsets, _builder) + + +@builtin +def store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union[constexpr, tensor]], value: tensor, + _builder=None) -> tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_store_tensor_descriptor', desc, offsets, value, _builder) + + +@builtin +def make_tensor_descriptor( + base: tensor, + shape: List[tensor], + strides: List[tensor], + block_shape: List[constexpr], + _builder=None, +) -> tensor_descriptor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_make_tensor_descriptor', base, shape, strides, block_shape, _builder) + + @_tensor_member_fn @builtin def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None): @@ -1737,6 +1790,57 @@ def advance(base, offsets, _builder=None): return semantic.advance(base, offsets, _builder) +@_tensor_member_fn +@builtin +def insert_slice(ful, sub, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + # flagtree backend specialization + from triton.runtime.driver import 
flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_insert_slice', ful, sub, offsets, sizes, strides, _builder, _generator) + + +@_tensor_member_fn +@builtin +def extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_extract_slice', ful, offsets, sizes, strides, _builder, _generator) + + +@_tensor_member_fn +@builtin +def get_element(src, indice, _builder=None, _generator=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_get_element', src, indice, _builder, _generator) + + +@builtin +def multibuffer(src: tensor, size, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_core_multibuffer', src, size, _builder) + + +@builtin +def sync_block_all(mode, event_id, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_core_sync_block_all', mode, event_id, _builder) + + +@builtin +def sync_block_set(sender, receiver, event_id, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_core_sync_block_set', sender, receiver, event_id, _builder) + + +@builtin +def sync_block_wait(sender, receiver, event_id, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_core_sync_block_wait', sender, receiver, event_id, _builder) + # ----------------------- # Atomic Memory Operations # ----------------------- @@ -1989,6 +2093,61 @@ def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=No return semantic.clamp(x, min, max, propagate_nan, _builder) +@builtin +def __add__(self, other, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_add', self, other, _builder) + + +@builtin +def __radd__(self, other, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_radd', self, other, _builder) + + +@builtin +def __sub__(self, other, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_sub', self, other, _builder) + + +@builtin +def __rsub__(self, other, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_rsub', self, other, _builder) + + +@builtin +def __mul__(self, other, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_mul', self, other, _builder) + + +@builtin +def __rmul__(self, other, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_rmul', self, other, 
_builder) + + +@builtin +def __lshift__(self, other, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_lshift', self, other, _builder) + + +@builtin +def __rshift__(self, other, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_rshift', self, other, _builder) + # ----------------------- # Reductions # ----------------------- @@ -2108,6 +2267,12 @@ def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None return rvalue, rindices +@builtin +def sort(ptr, dim=-1, descending=False, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_sort', ptr, dim, descending, _builder) + # ----------------------- # Scans # ----------------------- @@ -2184,6 +2349,13 @@ def histogram(input, num_bins, _builder=None, _generator=None): return semantic.histogram(input, num_bins, _builder) +@_tensor_member_fn +@builtin +def gather(src, index, axis, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_gather', src, index, axis, _builder) + # ----------------------- # Compiler Hint Ops # ----------------------- @@ -2256,6 +2428,12 @@ def assume(cond, _builder=None): return semantic.assume(semantic.to_tensor(cond, _builder), _builder) +@builtin +def compile_hint(ptr, hint_name, hint_val=None, _builder=None): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_core_compile_hint', ptr, hint_name, hint_val, _builder) + # ----------------------- # Debugging functions # ----------------------- @@ -2588,6 +2766,21 @@ def __next__(self): raise RuntimeError("tl.range can only be used in @triton.jit'd functions") +# flagtree backend specialization +class parallel(range): + """ + Iterator that counts upward forever, with parallel execution semantics. + + This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. + :param bind_sub_block: Tells the compiler if multiple vector cores participate in the loop. + This is used in the mixed cube-vector kernel on 910B. The number of vector cores is determined by the number of + iteration in this loop. Currently on 910B, max 2 vector cores could be used. 
+ """ + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, bind_sub_block: bool = False): + super().__init__(arg1, arg2, step, num_stages, loop_unroll_factor) + self.bind_sub_block = bind_sub_block + # ----------------------- # Extern functions # ----------------------- @@ -2692,3 +2885,9 @@ def binary_op_type_legalization(lhs, rhs, builder): def extern(fn): """A decorator for external functions.""" return builtin(fn) + + +def dtype_to_ir(self, builder: ir.builder) -> ir.type: + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return flagtree_backend_specialization('ext_core_dtype_to_ir', self, builder) diff --git a/python/triton/language/core_ext.py b/python/triton/language/core_ext.py index e466b0381..f3f4154eb 100644 --- a/python/triton/language/core_ext.py +++ b/python/triton/language/core_ext.py @@ -128,6 +128,7 @@ def dot( ) +# FIXME: non-exist in core.py @_tensor_member_fn @builtin def gather(src, index, axis, _builder=None): @@ -143,6 +144,7 @@ def gather(src, index, axis, _builder=None): return semantic.gather(src, index, axis, _builder) +# FIXME: non-exist in core.py @_tensor_member_fn @builtin def insert_slice(ful, sub, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: @@ -170,6 +172,7 @@ def insert_slice(ful, sub, offsets, sizes, strides, _builder=None, _generator=No return out +# FIXME: non-exist in core.py @_tensor_member_fn @builtin def extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: @@ -193,6 +196,7 @@ def extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) sub = semantic.extract_slice(ful, new_offsets, sizes, strides, _builder) return sub +# FIXME: non-exist in core.py @_tensor_member_fn @builtin def get_element(src, indice, _builder=None, _generator=None): @@ -213,31 +217,38 @@ def get_element(src, indice, _builder=None, _generator=None): ] return semantic.get_element(src, new_indice, _builder) +# FIXME: non-exist in core.py @builtin def __add__(self, other, _builder=None): return add(self, other, sanitize_overflow=False, _builder=_builder) +# FIXME: non-exist in core.py @builtin def __radd__(self, other, _builder=None): return add(other, self, sanitize_overflow=False, _builder=_builder) +# FIXME: non-exist in core.py @builtin def __sub__(self, other, _builder=None): return sub(self, other, sanitize_overflow=False, _builder=_builder) +# FIXME: non-exist in core.py @builtin def __rsub__(self, other, _builder=None): return sub(other, self, sanitize_overflow=False, _builder=_builder) +# FIXME: non-exist in core.py @builtin def __mul__(self, other, _builder=None): return mul(self, other, sanitize_overflow=False, _builder=_builder) +# FIXME: non-exist in core.py @builtin def __rmul__(self, other, _builder=None): return mul(other, self, sanitize_overflow=False, _builder=_builder) +# FIXME: non-exist in core.py @builtin def __lshift__(self, other, _builder=None): if self.type.scalar.is_floating(): @@ -247,6 +258,7 @@ def __lshift__(self, other, _builder=None): return semantic.shl(self, other, _builder) +# FIXME: non-exist in core.py @builtin def __rshift__(self, other, _builder=None): if self.type.scalar.is_floating(): @@ -259,6 +271,7 @@ def __rshift__(self, other, _builder=None): return semantic.lshr(self, other, _builder) +# FIXME: non-exist in core.py class parallel(range): """ Iterator that counts upward forever, with parallel execution semantics. 
@@ -274,6 +287,7 @@ def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_fact self.bind_sub_block = bind_sub_block +# FIXME: non-exist in core.py @builtin def compile_hint(ptr, hint_name, hint_val=None, _builder=None): def _unwrap(val): @@ -289,6 +303,7 @@ def _unwrap(val): semantic.compile_hint(ptr, hint_name, hint_val, _builder) +# FIXME: non-exist in core.py @builtin def sort(ptr, dim=-1, descending=False, _builder=None): """ @@ -320,6 +335,7 @@ def sort(ptr, dim=-1, descending=False, _builder=None): return ret +# FIXME: non-exist in core.py @builtin def multibuffer(src: tensor, size, _builder=None): """ @@ -332,6 +348,7 @@ def multibuffer(src: tensor, size, _builder=None): semantic.compile_hint(src, "multi_buffer", buffer_size, _builder) +# FIXME: non-exist in core.py @builtin def sync_block_all(mode, event_id, _builder=None): mode = _constexpr_to_value(mode) @@ -342,6 +359,7 @@ def sync_block_all(mode, event_id, _builder=None): semantic.custom_op(_builder, "sync_block_all", mode=mode, event_id=event_id) +# FIXME: non-exist in core.py @builtin def sync_block_set(sender, receiver, event_id, _builder=None): sender = _constexpr_to_value(sender) @@ -355,6 +373,7 @@ def sync_block_set(sender, receiver, event_id, _builder=None): semantic.custom_op(_builder, "sync_block_set", sender=sender, event_id=event_id) +# FIXME: non-exist in core.py @builtin def sync_block_wait(sender, receiver, event_id, _builder=None): sender = _constexpr_to_value(sender) @@ -368,6 +387,7 @@ def sync_block_wait(sender, receiver, event_id, _builder=None): semantic.custom_op(_builder, "sync_block_wait", sender=sender, event_id=event_id) +# FIXME: non-exist in core.py @builtin def load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union[constexpr, tensor]], _builder=None) -> tensor: @@ -375,6 +395,7 @@ def load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union return desc.load(offsets, _builder=_builder) +# FIXME: non-exist in core.py @builtin def store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union[constexpr, tensor]], value: tensor, _builder=None) -> tensor: @@ -382,6 +403,7 @@ def store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Unio return desc.store(offsets, value, _builder=_builder) +# FIXME: non-exist in core.py @builtin def make_tensor_descriptor( base: tensor, @@ -440,6 +462,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): return semantic.make_tensor_descriptor(base, shape, strides, block_shape, _builder) +# FIXME: non-exist in core.py def dtype_to_ir(self, builder: ir.builder) -> ir.type: if self.name.startswith("fp8"): raise ValueError(f'unexpected type fp8.') diff --git a/python/triton/testing.py b/python/triton/testing.py index 71a4f5346..b3bd7e7d8 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -522,35 +522,6 @@ def get_max_simd_tflops(dtype, clock_rate, device=None): tensor_descriptor, tensor_descriptor_type, ) - -from .language.core_ext import ( - dot, - cast, - gather, - get_element, - insert_slice, - extract_slice, - trans, - __add__, - __radd__, - __sub__, - __rsub__, - __mul__, - __rmul__, - __lshift__, - __rshift__, - parallel, - compile_hint, - make_tensor_descriptor, - load_tensor_descriptor, - store_tensor_descriptor, - multibuffer, - sync_block_all, - sync_block_set, - sync_block_wait, - dtype_to_ir, - sort -) from .language.standard_ext import flip, sigmoid, softmax, isfinited, finitef, rint, atan2 from .language.math_ext import ( umulhi, 
@@ -573,36 +544,11 @@ def get_max_simd_tflops(dtype, clock_rate, device=None): ) from . import language -language.cast = cast -language.dot = dot language.flip = flip language.sigmoid = sigmoid language.softmax = softmax -language.gather = gather -language.insert_slice = insert_slice -language.extract_slice = extract_slice -language.get_element = get_element -language.tensor.__add__ = __add__ -language.tensor.__radd__ = __radd__ -language.tensor.__sub__ = __sub__ -language.tensor.__rsub__ = __rsub__ -language.tensor.__mul__ = __mul__ -language.tensor.__rmul__ = __rmul__ -language.tensor.__lshift__ = __lshift__ -language.tensor.__rshift__ = __rshift__ -language.trans = trans -language.parallel = parallel -language.compile_hint = compile_hint -language.sort = sort -language.multibuffer = multibuffer -language.sync_block_all = sync_block_all -language.sync_block_set = sync_block_set -language.sync_block_wait = sync_block_wait -language.make_tensor_descriptor = make_tensor_descriptor language.tensor_descriptor = tensor_descriptor language.tensor_descriptor_type = tensor_descriptor_type -language.load_tensor_descriptor = load_tensor_descriptor -language.store_tensor_descriptor = store_tensor_descriptor language.umulhi = umulhi language.exp = exp @@ -619,7 +565,6 @@ def get_max_simd_tflops(dtype, clock_rate, device=None): language.tanh = tanh language.floor = floor language.ceil = ceil -language.core.dtype.to_ir = dtype_to_ir language.fma = fma language.math.umulhi = umulhi language.math.exp = exp diff --git a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py index b20045449..3c7ddc671 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py @@ -4,6 +4,7 @@ from .triton.runtime.jit import * from .triton.runtime.autotuner import * from .triton.language._utils import * +from .triton.language.core import * from .triton.language.semantic import * from .triton.testing import * @@ -47,6 +48,34 @@ # language._utils 'is_block_shape_check_power_of_two', 'get_primitive_bitwidth', + # language.core + "ext_cast_set_overflow_modes", + "ext_cast_check_overflow_mode", + "ext_trans_unwrap_iterable", + "check_dot_deprecated_param_allow_tf32", + "check_dot_invalid_input_precision", + "ext_core_gather", + "ext_core_insert_slice", + "ext_core_extract_slice", + "ext_core_get_element", + "ext_core_add", + "ext_core_radd", + "ext_core_sub", + "ext_core_rsub", + "ext_core_mul", + "ext_core_rmul", + "ext_core_lshift", + "ext_core_rshift", + "ext_core_compile_hint", + "ext_core_sort", + "ext_core_multibuffer", + "ext_core_sync_block_all", + "ext_core_sync_block_set", + "ext_core_sync_block_wait", + "ext_core_load_tensor_descriptor", + "ext_core_store_tensor_descriptor", + "ext_core_make_tensor_descriptor", + "ext_core_dtype_to_ir", # language.semantic "is_arange_check_power_of_two", "check_arange_less_than_max_numel", diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/language/core.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/core.py new file mode 100644 index 000000000..3bedb6383 --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/core.py @@ -0,0 +1,334 @@ +from typing import List, Sequence, Union +from triton._C.libtriton import ir +import triton.language.semantic as semantic +from triton.language.core import ( + 
_unwrap_iterable, + _constexpr_to_value, + constexpr, + tensor, + check_bit_width, + _unwrap_if_constexpr, + add, + sub, + mul, +) + +from triton.language.tensor_descriptor import tensor_descriptor, tensor_descriptor_base + +def ext_cast_set_overflow_modes(): + return ["trunc", "saturate"] + +def ext_cast_check_overflow_mode(overflow_mode, overflow_modes, ret, _builder): + if overflow_mode is not None: + if overflow_mode in overflow_modes: + semantic.compile_hint(ret, "overflow_mode", overflow_mode, _builder) + else: + raise ValueError(f"Unknown overflow_mode:{overflow_mode} is found.") + +def ext_trans_unwrap_iterable(dims): + return _unwrap_iterable(dims) + +def check_dot_deprecated_param_allow_tf32(allow_tf32): + assert ( + not allow_tf32 + ), "allow_tf32 is deprecated, please use input_precision='hf32' on Ascend instead." + +def check_dot_invalid_input_precision(input_precision): + assert input_precision not in [ + "tf32", + "tf32x3", + ], "input_precision == tf32 or tf32x3 is invalid, please use input_precision='hf32' on Ascend instead." + +def ext_core_gather(src, index, axis, _builder=None): + """Gather from a tensor along a given dimension. + :param src: the source tensor + :type src: Tensor + :param index: the index tensor + :type index: Tensor + :param axis: the dimension to gather along + :type axis: int + """ + axis = _constexpr_to_value(axis) + return semantic.gather(src, index, axis, _builder) + +def ext_core_insert_slice(ful, sub, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + """ + Insert a tensor to another tensor as specified by the operation’s offsets, sizes and strides arguments. + + :param ful: The tensor to receive tensor. + :type ful: Tensor + :param sub: The tensor to be inserted. + :type sub: Tensor + :param offsets: + :type offsets: tuple of ints + :param sizes: + :type sizes: tuple of ints + :param strides: + :type strides: tuple of ints + """ + assert len(ful.shape) > 0 + assert len(ful.shape) == len(sub.shape) + new_offsets = [ + semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o + for o in offsets + ] + out = semantic.insert_slice(ful, sub, new_offsets, sizes, strides, _builder) + return out + +def ext_core_extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + """ + Extract a tensor from another tensor as specified by the operation’s offsets, sizes and strides arguments. + + :param ful: The tensor to split. + :type ful: Tensor + :param offsets: + :type offsets: tuple of ints + :param sizes: + :type sizes: tuple of ints + :param strides: + :type strides: tuple of ints + """ + assert len(ful.shape) > 0 + new_offsets = [ + semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o + for o in offsets + ] + sub = semantic.extract_slice(ful, new_offsets, sizes, strides, _builder) + return sub + +def ext_core_get_element(src, indice, _builder=None, _generator=None): + """ + get_element op reads a ranked tensor and returns one element as specified by the given indices. + The result of the op is a value with the same type as the elements of the tensor. + The arity of indices must match the rank of the accessed value. + + :param src: The tensor to be accessed. 
+    :type src: Tensor
+    :param indice:
+    :type indice: tuple of ints
+    """
+    assert len(src.shape) > 0
+    new_indice = [
+        semantic.to_tensor(i, _builder) if isinstance(i, constexpr) else i
+        for i in indice
+    ]
+    return semantic.get_element(src, new_indice, _builder)
+
+def ext_core_add(self, other, _builder=None):
+    return add(self, other, sanitize_overflow=False, _builder=_builder)
+
+def ext_core_radd(self, other, _builder=None):
+    return add(other, self, sanitize_overflow=False, _builder=_builder)
+
+def ext_core_sub(self, other, _builder=None):
+    return sub(self, other, sanitize_overflow=False, _builder=_builder)
+
+def ext_core_rsub(self, other, _builder=None):
+    return sub(other, self, sanitize_overflow=False, _builder=_builder)
+
+def ext_core_mul(self, other, _builder=None):
+    return mul(self, other, sanitize_overflow=False, _builder=_builder)
+
+def ext_core_rmul(self, other, _builder=None):
+    return mul(other, self, sanitize_overflow=False, _builder=_builder)
+
+def ext_core_lshift(self, other, _builder=None):
+    if self.type.scalar.is_floating():
+        raise TypeError(f"unexpected type {self.type.scalar}")
+    check_bit_width(self, other)
+    other = _unwrap_if_constexpr(other)
+    return semantic.shl(self, other, _builder)
+
+def ext_core_rshift(self, other, _builder=None):
+    if self.type.scalar.is_floating():
+        raise TypeError(f"unexpected type {self.type.scalar}")
+    other = _unwrap_if_constexpr(other)
+    check_bit_width(self, other)
+    if self.dtype.is_int_signed():
+        return semantic.ashr(self, other, _builder)
+    else:
+        return semantic.lshr(self, other, _builder)
+
+def ext_core_compile_hint(ptr, hint_name, hint_val=None, _builder=None):
+    def _unwrap(val):
+        return _unwrap_if_constexpr(val) if val else val
+
+    hint_name = _constexpr_to_value(hint_name)
+    assert isinstance(hint_name, str), f"hint name: {hint_name} is not a string"
+    if isinstance(hint_val, list):
+        hint_val = [_unwrap(val) for val in hint_val]
+    else:
+        hint_val = _unwrap(hint_val)
+    hint_val = _unwrap_if_constexpr(hint_val) if hint_val else hint_val
+    semantic.compile_hint(ptr, hint_name, hint_val, _builder)
+
+def ext_core_sort(ptr, dim=-1, descending=False, _builder=None):
+    """
+    Frontend interface for Triton sort.
+
+    Args:
+        ptr: tl.tensor, the input tensor
+        dim: int or tl.constexpr[int], the dimension to sort along
+        descending: bool or tl.constexpr[bool], whether to sort in descending order
+        _builder: ir.builder, the underlying IR builder
+    Returns:
+        values: tl.tensor, the sorted values (same type as the input)
+    """
+
+    try:
+        dim = int(dim.value) if hasattr(dim, "value") else int(dim)
+    except Exception as e:
+        raise TypeError(f"dim must be an integer (or tl.constexpr int), got {dim!r}. 
Error: {str(e)}") from e + + if hasattr(descending, "value"): + descending = bool(descending.value) + else: + descending = bool(descending) + + ret = semantic.sort(ptr, dim, descending, _builder) + base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type + if base_ty.is_int8() or base_ty.is_int16(): + semantic.compile_hint(ret, "overflow_mode", constexpr("saturate"), _builder) + return ret + +def ext_core_multibuffer(src: tensor, size, _builder=None): + """ + Set multi_buffer for an existing tensor + :src: tensor set to bufferize multiple time + :size: number of copies + """ + buffer_size = _constexpr_to_value(size) + assert isinstance(buffer_size, int) and buffer_size == 2, f"only support bufferize equals 2" + semantic.compile_hint(src, "multi_buffer", buffer_size, _builder) + +def ext_core_sync_block_all(mode, event_id, _builder=None): + mode = _constexpr_to_value(mode) + event_id = _constexpr_to_value(event_id) + assert isinstance(mode, str), f"mode: {mode} is not string" + assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" + assert mode == "all_cube" or mode == "all_vector" or mode == "all", f"ERROR: mode = {mode}, only supports all_cube/all_vector/all" + semantic.custom_op(_builder, "sync_block_all", mode=mode, event_id=event_id) + +def ext_core_sync_block_set(sender, receiver, event_id, _builder=None): + sender = _constexpr_to_value(sender) + receiver = _constexpr_to_value(receiver) + event_id = _constexpr_to_value(event_id) + assert isinstance(sender, str) and (sender == "cube" or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" + assert isinstance(receiver, str) and (receiver == "cube" or receiver == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" + assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" + if sender == receiver: + raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube') + semantic.custom_op(_builder, "sync_block_set", sender=sender, event_id=event_id) + +def ext_core_sync_block_wait(sender, receiver, event_id, _builder=None): + sender = _constexpr_to_value(sender) + receiver = _constexpr_to_value(receiver) + event_id = _constexpr_to_value(event_id) + assert isinstance(sender, str) and (sender == "cube" or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" + assert isinstance(receiver, str) and (receiver == "cube" or receiver == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" + assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" + if sender == receiver: + raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube') + semantic.custom_op(_builder, "sync_block_wait", sender=sender, event_id=event_id) + +def ext_core_load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union[constexpr, tensor]], + _builder=None) -> tensor: + """Load a block of data from a tensor descriptor.""" + return desc.load(offsets, _builder=_builder) + +def ext_core_store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union[constexpr, tensor]], value: tensor, + _builder=None) -> tensor: + """Store a block of data to a tensor descriptor.""" + return desc.store(offsets, value, _builder=_builder) + +def ext_core_make_tensor_descriptor( + base: tensor, + shape: List[tensor], + 
strides: List[tensor], + block_shape: List[constexpr], + _builder=None, +) -> tensor_descriptor: + """Make a tensor descriptor object + + :param base: the base pointer of the tensor, must be 16-byte aligned + :param shape: A list of non-negative integers representing the tensor shape + :param strides: A list of tensor strides. Leading dimensions must be multiples + of 16-byte strides and the last dimension must be contiguous. + :param block_shape: The shape of block to be loaded/stored from global memory + + Notes + ***** + On NVIDIA GPUs with TMA support, this will result in a TMA descriptor object + and loads and stores from the descriptor will be backed by the TMA hardware. + + Currently only 2-5 dimensional tensors are supported. + + Example + ******* + .. code-block:: python + + @triton.jit + def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + desc = tl.make_tensor_descriptor( + in_out_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + + value = desc.load([moffset, noffset]) + desc.store([moffset, noffset], tl.abs(value)) + + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + M, N = 256, 256 + x = torch.randn(M, N, device="cuda") + M_BLOCK, N_BLOCK = 32, 32 + grid = (M // M_BLOCK, N // N_BLOCK) + inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK) + + """ + return semantic.make_tensor_descriptor(base, shape, strides, block_shape, _builder) + +def ext_core_dtype_to_ir(self, builder: ir.builder) -> ir.type: + if self.name.startswith("fp8"): + raise ValueError(f'unexpected type fp8.') + + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name in ('int8', 'uint8'): + return builder.get_int8_ty() + elif self.name in ('int16', 'uint16'): + return builder.get_int16_ty() + elif self.name in ('int32', 'uint32'): + return builder.get_int32_ty() + elif self.name in ('int64', 'uint64'): + return builder.get_int64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e5b16': + return builder.get_fp8e5b16_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b8': + return builder.get_fp8e4b8_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') From 42f0bcea14397d6be240ee510ce618a24cc4307d Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Thu, 20 Nov 2025 18:37:44 +0800 Subject: [PATCH 20/23] [Decoupling] Decouple language.core extension for ascend (language.__ini__.py) --- python/triton/language/__init__.py | 45 ++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 6502a5348..773a52737 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -105,6 +105,29 @@ view, void, where, + gather, + get_element, + insert_slice, + extract_slice, + __add__, + __radd__, + __sub__, + __rsub__, + __mul__, + __rmul__, + 
__lshift__, + __rshift__, + parallel, + compile_hint, + make_tensor_descriptor, + load_tensor_descriptor, + store_tensor_descriptor, + multibuffer, + sync_block_all, + sync_block_set, + sync_block_wait, + dtype_to_ir, + sort ) from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor, ceil) @@ -152,6 +175,7 @@ "cdiv", "ceil", "clamp", + "compile_hint", "const", "constexpr", "cos", @@ -164,11 +188,13 @@ "dot", "dot_scaled", "dtype", + "dtype_to_ir", "erf", "exp", "exp2", "expand_dims", "extra", + "extract_slice", "fdiv", "flip", "float16", @@ -183,8 +209,11 @@ "fma", "full", "function_type", + "gather", + "get_element", "histogram", "inline_asm_elementwise", + "insert_slice", "interleave", "int1", "int16", @@ -194,9 +223,11 @@ "ir", "join", "load", + "load_tensor_descriptor", "log", "log2", "make_block_ptr", + "make_tensor_descriptor", "math", "max", "max_constancy", @@ -204,9 +235,11 @@ "maximum", "min", "minimum", + "multibuffer", "multiple_of", "num_programs", "pair_uniform_to_normal", + "parallel", "permute", "philox", "philox_impl", @@ -236,8 +269,12 @@ "static_print", "static_range", "store", + "store_tensor_descriptor", "sum", "swizzle2d", + "sync_block_all", + "sync_block_set", + "sync_block_wait", "tensor", "trans", "triton", @@ -253,6 +290,14 @@ "xor_sum", "zeros", "zeros_like", + "__add__", + "__radd__", + "__sub__", + "__rsub__", + "__mul__", + "__rmul__", + "__lshift__", + "__rshift__" ] From 57c0717f276492333e096d8412caa3c8b8e1e189 Mon Sep 17 00:00:00 2001 From: liuyunqi20 Date: Tue, 25 Nov 2025 15:08:05 +0800 Subject: [PATCH 21/23] [Decoupling] Decouple language.math extension for ascend --- python/triton/testing.py | 118 ++++-------- .../ascend/language/ascend/libdevice.py | 175 +++++++++++++++++- 2 files changed, 206 insertions(+), 87 deletions(-) diff --git a/python/triton/testing.py b/python/triton/testing.py index b3bd7e7d8..46f01e784 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -523,25 +523,6 @@ def get_max_simd_tflops(dtype, clock_rate, device=None): tensor_descriptor_type, ) from .language.standard_ext import flip, sigmoid, softmax, isfinited, finitef, rint, atan2 -from .language.math_ext import ( - umulhi, - exp, - exp2, - log, - log2, - cos, - sin, - sqrt, - sqrt_rn, - rsqrt, - div_rn, - erf, - tanh, - floor, - ceil, - _check_dtype, - fma, -) from . 
import language language.flip = flip @@ -550,75 +531,40 @@ def get_max_simd_tflops(dtype, clock_rate, device=None): language.tensor_descriptor = tensor_descriptor language.tensor_descriptor_type = tensor_descriptor_type -language.umulhi = umulhi -language.exp = exp -language.exp2 = exp2 -language.log = log -language.log2 = log2 -language.cos = cos -language.sin = sin -language.sqrt = sqrt -language.sqrt_rn = sqrt_rn -language.rsqrt = rsqrt -language.div_rn = div_rn -language.erf = erf -language.tanh = tanh -language.floor = floor -language.ceil = ceil -language.fma = fma -language.math.umulhi = umulhi -language.math.exp = exp -language.math.exp2 = exp2 -language.math.log = log -language.math.log2 = log2 -language.math.cos = cos -language.math.sin = sin -language.math.sqrt = sqrt -language.math.sqrt_rn = sqrt_rn -language.math.rsqrt = rsqrt -language.math.div_rn = div_rn -language.math.erf = erf -language.math.tanh = tanh -language.math.floor = floor -language.math.ceil = ceil -language.math._check_dtype = _check_dtype -language.math.fma = fma -language.math.isnan = language.extra.ascend.libdevice.isnan -language.math.isinf = language.extra.ascend.libdevice.isinf -language.math.reciprocal = language.extra.ascend.libdevice.reciprocal -language.math.log1p = language.extra.ascend.libdevice.log1p -language.math.relu = language.extra.ascend.libdevice.relu -language.math.tan = language.extra.ascend.libdevice.tan -language.math.atan = language.extra.ascend.libdevice.atan +language.umulhi = language.extra.ascend.libdevice.umulhi +language.exp = language.extra.ascend.libdevice.exp +language.exp2 = language.extra.ascend.libdevice.exp2 +language.log = language.extra.ascend.libdevice.log +language.log2 = language.extra.ascend.libdevice.log2 +language.cos = language.extra.ascend.libdevice.cos +language.sin = language.extra.ascend.libdevice.sin +language.sqrt = language.extra.ascend.libdevice.sqrt +language.sqrt_rn = language.extra.ascend.libdevice.sqrt_rn +language.rsqrt = language.extra.ascend.libdevice.rsqrt +language.div_rn = language.extra.ascend.libdevice.div_rn +language.erf = language.extra.ascend.libdevice.erf +language.tanh = language.extra.ascend.libdevice.tanh +language.floor = language.extra.ascend.libdevice.floor +language.ceil = language.extra.ascend.libdevice.ceil +language.fma = language.extra.ascend.libdevice.fma +language.math.umulhi = language.extra.ascend.libdevice.umulhi +language.math.exp = language.extra.ascend.libdevice.exp +language.math.exp2 = language.extra.ascend.libdevice.exp2 +language.math.log = language.extra.ascend.libdevice.log +language.math.log2 = language.extra.ascend.libdevice.log2 +language.math.cos = language.extra.ascend.libdevice.cos +language.math.sin = language.extra.ascend.libdevice.sin +language.math.sqrt = language.extra.ascend.libdevice.sqrt +language.math.sqrt_rn = language.extra.ascend.libdevice.sqrt_rn +language.math.rsqrt = language.extra.ascend.libdevice.rsqrt +language.math.div_rn = language.extra.ascend.libdevice.div_rn +language.math.erf = language.extra.ascend.libdevice.erf language.math.tanh = language.extra.ascend.libdevice.tanh -language.math.ilogb = language.extra.ascend.libdevice.ilogb -language.math.ldexp = language.extra.ascend.libdevice.ldexp -language.math.pow = language.extra.ascend.libdevice.pow -language.math.flip = language.extra.ascend.libdevice.flip -language.math.atan2 = language.extra.ascend.libdevice.atan2 -language.math.div_rz = language.extra.ascend.libdevice.div_rz -language.math.fmod = language.extra.ascend.libdevice.fmod 
-language.math.trunc = language.extra.ascend.libdevice.trunc -language.math.round = language.extra.ascend.libdevice.round +language.math.floor = language.extra.ascend.libdevice.floor +language.math.ceil = language.extra.ascend.libdevice.ceil +language.math._check_dtype = language.extra.ascend.libdevice._check_dtype +language.math.fma = language.extra.ascend.libdevice.fma language.math.finitef = finitef language.math.isfinited = isfinited language.math.rint = rint language.math.atan2 = atan2 -language.extra.ascend.libdevice.umulhi = language.math.umulhi -language.extra.ascend.libdevice.exp = language.math.exp -language.extra.ascend.libdevice.exp2 = language.math.exp2 -language.extra.ascend.libdevice.log = language.math.log -language.extra.ascend.libdevice.log2 = language.math.log2 -language.extra.ascend.libdevice.cos = language.math.cos -language.extra.ascend.libdevice.sin = language.math.sin -language.extra.ascend.libdevice.sqrt = language.math.sqrt -language.extra.ascend.libdevice.sqrt_rn = language.math.sqrt_rn -language.extra.ascend.libdevice.rsqrt = language.math.rsqrt -language.extra.ascend.libdevice.div_rn = language.math.div_rn -language.extra.ascend.libdevice.erf = language.math.erf -language.extra.ascend.libdevice.tanh = language.math.tanh -language.extra.ascend.libdevice.floor = language.math.floor -language.extra.ascend.libdevice.ceil = language.math.ceil -language.extra.ascend.libdevice.fdiv = language.math.fdiv -language.extra.ascend.libdevice.fma = language.math.fma -language.extra.ascend.libdevice.abs = language.math.abs diff --git a/third_party/ascend/language/ascend/libdevice.py b/third_party/ascend/language/ascend/libdevice.py index 9098856ec..b4f057d47 100644 --- a/third_party/ascend/language/ascend/libdevice.py +++ b/third_party/ascend/language/ascend/libdevice.py @@ -1,4 +1,177 @@ +from functools import wraps +from typing import List from triton.language import core +from triton.language.math import _add_math_1arg_docstr, _add_math_2arg_docstr, _add_math_3arg_docstr +from triton.language import semantic + +T = core.TypeVar('T') + + +def _check_dtype(dtypes: List[str]) -> T: + """ + We're following libdevice's convention to check accepted data types for math functions. + It is not a good practice to support all data types as accelerators/GPUs don't support + many float16 and bfloat16 math operations. + We should let the users know that they are using and invoke explicit cast to convert + the data type to the supported one. 
+ """ + + def wrapper(fn): + + @wraps(fn) + def check(*args, **kwargs): + # concatenate args and kwargs + all_args = list(args) + list(kwargs.values()) + for arg in [a for a in all_args if isinstance(a, core.tensor)]: + arg_type = arg.type.scalar.name + if hasattr(arg, 'was_bool_to_int8') and arg.was_bool_to_int8: + # In Triton, int1 maps to the boolean type + arg_type = 'int1' + if arg_type not in dtypes: + raise ValueError(f"Expected dtype {dtypes} but got {arg_type}") + return fn(*args, **kwargs) + + return check + + return wrapper + + +@core.extern +@_check_dtype(dtypes=["int32", "uint32"]) +@_add_math_2arg_docstr("most significant N bits of the 2N-bit product") +def umulhi(x, y, _builder=None): + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("exponential") +@core._tensor_member_fn +def exp(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_exp(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("exponential (base 2)") +@core._tensor_member_fn +def exp2(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_exp2(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("natural logarithm") +@core._tensor_member_fn +def log(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_log(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("logarithm (base 2)") +@core._tensor_member_fn +def log2(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_log2(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("cosine") +@core._tensor_member_fn +def cos(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_cos(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("sine") +@core._tensor_member_fn +def sin(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_sin(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("fast square root") +@core._tensor_member_fn +def sqrt(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_sqrt(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)") +@core._tensor_member_fn +def sqrt_rn(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_precise_sqrt(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("inverse square root") +@core._tensor_member_fn +def rsqrt(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_rsqrt(x.handle), x.type) + +@core.extern +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)") +def div_rn(x, y, _builder=None): + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, 
_builder)
+    x, y = core.binary_op_type_legalization(x, y, _builder)
+    return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type)
+
+@core.extern
+@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
+@_add_math_1arg_docstr("error function")
+@core._tensor_member_fn
+def erf(x, _builder=None):
+    x = semantic.to_tensor(x, _builder)
+    return core.tensor(_builder.create_erf(x.handle), x.type)
+
+@core.extern
+@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
+@_add_math_1arg_docstr("hyperbolic tangent")
+@core._tensor_member_fn
+def tanh(x, _builder=None):
+    x = semantic.to_tensor(x, _builder)
+    return core.tensor(_builder.create_tanh(x.handle), x.type)
+
+@core.extern
+@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
+@_add_math_1arg_docstr("floor")
+@core._tensor_member_fn
+def floor(x, _builder=None):
+    x = semantic.to_tensor(x, _builder)
+    return core.tensor(_builder.create_floor(x.handle), x.type)
+
+
+@core.extern
+@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
+@_add_math_1arg_docstr("ceil")
+@core._tensor_member_fn
+def ceil(x, _builder=None):
+    x = semantic.to_tensor(x, _builder)
+    return core.tensor(_builder.create_ceil(x.handle), x.type)
+
+
+@core.extern
+@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
+@_add_math_3arg_docstr("fused multiply-add")
+def fma(x, y, z, _builder=None):
+    x = semantic.to_tensor(x, _builder)
+    y = semantic.to_tensor(y, _builder)
+    z = semantic.to_tensor(z, _builder)
+    x, y = core.binary_op_type_legalization(x, y, _builder)
+    z, x = core.binary_op_type_legalization(z, x, _builder)
+    z, y = core.binary_op_type_legalization(z, y, _builder)
+    return core.tensor(_builder.create_fma(x.handle, y.handle, z.handle), x.type)
+
 
 @core.extern
 def reciprocal(arg0, _builder=None):
@@ -151,5 +324,5 @@ def trunc(arg0, _builder=None):
 def round(arg0, _builder=None):
     return core.extern_elementwise(
         "", "", [arg0], {
-            (core.dtype("fp32"), ): ("__hmf_roundf", core.dtype("fp32")),
+            (core.dtype("fp32"), ): ("__hmf_roundf", core.dtype("fp32")),
         }, is_pure=True, _builder=_builder)
\ No newline at end of file

From 93c97a9ae5250cacf161487a48a172c7dd01b31c Mon Sep 17 00:00:00 2001
From: liuyunqi20
Date: Wed, 3 Dec 2025 11:32:10 +0800
Subject: [PATCH 22/23] [Decoupling] Fix dot_scaled decoupled function params

---
 python/triton/language/semantic.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 6f28d3efb..a7d0e601a 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -1681,8 +1681,8 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor,
     flagtree_backend_specialization('check_dot_scaled_lhs_scale_dtype', lhs_scale)
     flagtree_backend_specialization('check_dot_scaled_rhs_scale_dtype', rhs_scale, rhs_scale_is_none)
 
-    lhs = flagtree_backend_specialization('dot_scaled_lhs_bitcast_to_fp_type')
-    rhs = flagtree_backend_specialization('dot_scaled_rhs_bitcast_to_fp_type')
+    lhs = flagtree_backend_specialization('dot_scaled_lhs_bitcast_to_fp_type', lhs, lhs_format, builder)
+    rhs = flagtree_backend_specialization('dot_scaled_rhs_bitcast_to_fp_type', rhs, rhs_format, builder)
 
     flagtree_backend_specialization('check_dot_scaled_dimension', lhs, rhs)
     M = lhs.type.shape[-2]

From a692abc2c1fc1a8d00b7e1a839bfe34b95d02503 Mon Sep 17 00:00:00 2001
From: liuyunqi20
Date: Thu, 4 Dec 2025 11:22:05 +0800
Subject: [PATCH 23/23] [Decoupling] Decouple language.tensor_descriptor extension for ascend

---
 python/triton/language/core.py | 6 +-
python/triton/language/semantic.py | 2 +- python/triton/testing.py | 6 - .../triton/language/core.py | 2 +- .../triton/language/semantic.py | 5 +- .../triton/language/tensor_descriptor.py | 695 ++++++++++++++++++ .../triton/testing.py | 205 ------ 7 files changed, 702 insertions(+), 219 deletions(-) create mode 100644 third_party/ascend/backend/flagtree_backend_specialization/triton/language/tensor_descriptor.py diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 95dfee352..e7260b28b 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1684,7 +1684,7 @@ def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None): @builtin -def load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union[constexpr, tensor]], +def load_tensor_descriptor(desc, offsets: Sequence[Union[constexpr, tensor]], _builder=None) -> tensor: # flagtree backend specialization from triton.runtime.driver import flagtree_backend_specialization @@ -1692,7 +1692,7 @@ def load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union @builtin -def store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union[constexpr, tensor]], value: tensor, +def store_tensor_descriptor(desc, offsets: Sequence[Union[constexpr, tensor]], value: tensor, _builder=None) -> tensor: # flagtree backend specialization from triton.runtime.driver import flagtree_backend_specialization @@ -1706,7 +1706,7 @@ def make_tensor_descriptor( strides: List[tensor], block_shape: List[constexpr], _builder=None, -) -> tensor_descriptor: +): # flagtree backend specialization from triton.runtime.driver import flagtree_backend_specialization return flagtree_backend_specialization('ext_core_make_tensor_descriptor', base, shape, strides, block_shape, _builder) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index a7d0e601a..8d3b656dd 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1989,7 +1989,7 @@ def make_tensor_descriptor( strides: List[tl.tensor], block_shape: List[tl.constexpr], builder: ir.builder -) -> tensor_descriptor: +): # flagtree backend specialization from triton.runtime.driver import flagtree_backend_specialization return flagtree_backend_specialization('ext_semantic_make_tensor_descriptor', base, shape, strides, block_shape, builder) diff --git a/python/triton/testing.py b/python/triton/testing.py index 46f01e784..b21d4911c 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -518,18 +518,12 @@ def get_max_simd_tflops(dtype, clock_rate, device=None): # Patch the triton language API here because triton's __init__.py # import testing in the last stages. -from .language.tensor_descriptor import ( - tensor_descriptor, - tensor_descriptor_type, -) from .language.standard_ext import flip, sigmoid, softmax, isfinited, finitef, rint, atan2 from . 
import language language.flip = flip language.sigmoid = sigmoid language.softmax = softmax -language.tensor_descriptor = tensor_descriptor -language.tensor_descriptor_type = tensor_descriptor_type language.umulhi = language.extra.ascend.libdevice.umulhi language.exp = language.extra.ascend.libdevice.exp diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/language/core.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/core.py index 3bedb6383..76b78ce6f 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/triton/language/core.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/core.py @@ -13,7 +13,7 @@ mul, ) -from triton.language.tensor_descriptor import tensor_descriptor, tensor_descriptor_base +from .tensor_descriptor import tensor_descriptor, tensor_descriptor_base def ext_cast_set_overflow_modes(): return ["trunc", "saturate"] diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/language/semantic.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/semantic.py index 69821759d..e08f01520 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/triton/language/semantic.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/semantic.py @@ -1,10 +1,9 @@ from typing import List import triton.language as tl from triton._C.libtriton import ir -from triton.language import cast -from triton.language.semantic import to_tensor, bitcast, wrap_tensor +from triton.language.semantic import to_tensor, bitcast, wrap_tensor, cast from triton.language._utils import TRITON_MAX_TENSOR_NUMEL -from triton.language.tensor_descriptor import ( +from .tensor_descriptor import ( _unwrap_if_constexpr, _unwrap_shape, block_type, diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/language/tensor_descriptor.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/tensor_descriptor.py new file mode 100644 index 000000000..ca0fbc70c --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/tensor_descriptor.py @@ -0,0 +1,695 @@ +# TODO: When upgrading to Triton 3.4.0, remove this file, +# use the upstream Triton functions, and update core.py and semantic.py accordingly. 
+from __future__ import annotations + +import builtins +from typing import List, Tuple, Sequence, TypeVar +from enum import Enum + +from triton._C.libtriton import ir +from triton.language.core import ( + builtin, + constexpr, + tensor, + _value, + void as real_void, +) + +from triton.language.semantic import ( + _convert_to_ir_values, + _str_to_load_cache_modifier, + _str_to_eviction_policy, +) + +from triton.language._utils import validate_block_shape + + +def _unwrap_if_constexpr(o): + if isinstance(o, list): + return [_unwrap_if_constexpr(x) for x in o] + if isinstance(o, builtins.tuple): + return builtins.tuple(_unwrap_if_constexpr(x) for x in o) + if isinstance(o, tuple): + return tuple(_unwrap_if_constexpr(x) for x in o) + return o.value if isinstance(o, constexpr) else o + + +def _unwrap_shape(shape): + shape = _unwrap_if_constexpr(shape) + return [_unwrap_if_constexpr(s) for s in shape] + + +def _normalize_tuple(t): + normalized_tuple = _unwrap_if_constexpr(t) + if isinstance(normalized_tuple, (list, builtins.tuple)): + normalized_tuple = tuple(normalized_tuple) + return normalized_tuple + + +def descriptor_load(desc: tensor_descriptor_base, offsets, cache_modifier: str, + eviction_policy: str, builder: ir.builder) -> tensor: + assert isinstance(desc, tensor_descriptor_base) + ndim = len(desc.block_shape) + assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" + + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + x = builder.create_descriptor_load(desc.handle, offsets, _str_to_load_cache_modifier(cache_modifier), + _str_to_eviction_policy(eviction_policy)) + return tensor(x, desc.block_type) + + +def validate_store_like(desc: tensor_descriptor_base, value: tensor, offsets) -> None: + assert isinstance(desc, tensor_descriptor_base) + ndim = len(desc.block_shape) + assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" + assert value.shape == desc.block_shape + + +def descriptor_store(desc: tensor_descriptor_base, value: tensor, offsets, builder: ir.builder) -> tensor: + validate_store_like(desc, value, offsets) + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + return tensor(builder.create_descriptor_store(desc.handle, value.handle, offsets), real_void) + + + +class base_value(_value): + """Base class of values that exist in the triton IR (i.e. not constexprs). + """ + type: base_type + + def _flatten_ir(self, handles: List[ir.value]) -> None: + """Flatten frontend value into a sequence of mlir handles, which are appended + to the output list + """ + raise NotImplementedError + + +class base_type: + + def __eq__(self, other): + raise NotImplementedError("Types must implement __eq__") + + def __ne__(self, other): + return not (self == other) + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]: + """Build a frontend value with the current dtype, wrapping a list of existing handles. + cursor is the index of the first handle relevant to this value, and the function + should return the updated cursor position after any handles consumed by the created value. 
+ """ + raise NotImplementedError + + def mangle(self) -> str: + raise NotImplementedError(f"NYI: Type mangling for type {self.__class__}") + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + raise NotImplementedError + + +class tuple(base_value): + + def __init__(self, args: Sequence, type: tuple_type = None): + self.values = [i for i in args] + + def get_type(x): + if isinstance(x, dtype): + return dtype + if isinstance(x, (int, float)): + return constexpr + return x.type + + self.type = type or tuple_type([get_type(x) for x in self.values]) + + def __getitem__(self, idx: constexpr): + if isinstance(idx, int): + idx = constexpr(idx) + if isinstance(idx, constexpr): + return self.values[idx] + else: + assert isinstance(idx, (slice, builtins.slice)) + return tuple(self.values[idx.start:idx.stop:idx.step]) + + def __getattr__(self, name): + return self.values[self.type.fields.index(name)] + + def __setitem__(self, idx: constexpr, value): + if isinstance(idx, int): + idx = constexpr(idx) + assert isinstance(idx, constexpr) + self.values[idx] = value + + def __add__(self, other): + other = _normalize_tuple(other) + return tuple(self.values + other.values) + + def __mul__(self, other): + assert isinstance(other, constexpr) + return tuple(self.values * other.value) + + def __eq__(self, other): + other = _normalize_tuple(other) + return constexpr(self.values == other.values) + + def __hash__(self): + return hash(builtins.tuple(self.values)) + + def __str__(self): + return str([str(x) for x in self.values]) + + def __iter__(self): + return iter(self.values) + + def __len__(self): + return len(self.values) + + def _flatten_ir(self, handles: List[ir.value]): + for v in self.values: + print("[debug]tuple _flatten_ir: value:", v) + v._flatten_ir(handles) + print("[debug]tuple _flatten_ir: handles:", handles) + + def __repr__(self): + return f"({' ,'.join(repr(x) for x in self.values)})" + + +class tuple_type(base_type): + + def __init__(self, types, fields=None): + self.types = types + self.fields = fields or [''] * len(types) + self.name = '[' + ','.join([f"{k}:{v}" for k, v in zip(self.fields, self.types)]) + ']' + + def __str__(self): + return self.name + + def __iter__(self): + return iter(self.types) + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]): + for ty in self.types: + if not isinstance(ty, constexpr): + ty._flatten_ir_types(builder, out) + + def __getitem__(self, index: int) -> dtype: + return self.types[index] + + def __eq__(self, other): + return type(self) is type(other) and self.types == other.types and self.fields == other.fields + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tuple, int]: + values = [] + for ty in self.types: + value, cursor = ty._unflatten_ir(handles, cursor) + values.append(value) + return tuple(values, self), cursor + + def mangle(self): + return 'T' + '_'.join(ty.mangle for ty in self.types) + 'T' + + +class dtype(base_type): + SINT_TYPES = ['int8', 'int16', 'int32', 'int64'] + UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64'] + FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64'] + STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64'] + OTHER_TYPES = ['void'] + + class SIGNEDNESS(Enum): + SIGNED = 0 + UNSIGNED = 1 + + class KIND(Enum): + BOOLEAN = 0 + INTEGRAL = 1 + FLOATING = 2 + + def __init__(self, name): + name = _unwrap_if_constexpr(name) + self.name = name + assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + 
dtype.FP_TYPES + dtype.OTHER_TYPES, name + # flagtree backend specialization + from ._utils import spec_get_primitive_bitwidth + get_primitive_bitwidth = spec_get_primitive_bitwidth + self.primitive_bitwidth = get_primitive_bitwidth(name) + self.itemsize = self.primitive_bitwidth // 8 + if name in dtype.SINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.SIGNED + self.int_bitwidth = self.primitive_bitwidth + elif name in dtype.UINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.UNSIGNED + self.int_bitwidth = self.primitive_bitwidth + elif name in dtype.FP_TYPES: + if name == 'fp8e4b15': + self.fp_mantissa_width = 3 + self.exponent_bias = 15 + elif name == 'fp8e4nv': + self.fp_mantissa_width = 3 + self.exponent_bias = 7 + elif name == 'fp8e4b8': + self.fp_mantissa_width = 3 + self.exponent_bias = 8 + elif name == 'fp8e5': + self.fp_mantissa_width = 2 + self.exponent_bias = 15 + elif name == 'fp8e5b16': + self.fp_mantissa_width = 2 + self.exponent_bias = 16 + elif name == 'fp16': + self.fp_mantissa_width = 10 + self.exponent_bias = 15 + elif name == 'bf16': + self.fp_mantissa_width = 7 + self.exponent_bias = 127 + elif name == 'fp32': + self.fp_mantissa_width = 23 + self.exponent_bias = 127 + elif name == 'fp64': + self.fp_mantissa_width = 52 + self.exponent_bias = 1023 + else: + raise RuntimeError(f'Unsupported floating-point type {name}') + + def is_fp8(self): + return 'fp8' in self.name + + def is_fp8e4nv(self): + return self.name == 'fp8e4nv' + + def is_fp8e4b8(self): + return self.name == 'fp8e4b8' + + def is_fp8e4b15(self): + return self.name == 'fp8e4b15' + + def is_fp8e5(self): + return self.name == 'fp8e5' + + def is_fp8e5b16(self): + return self.name == 'fp8e5b16' + + def is_fp16(self): + return self.name == 'fp16' + + def is_bf16(self): + return self.name == 'bf16' + + def is_fp32(self): + return self.name == 'fp32' + + def is_fp64(self): + return self.name == 'fp64' + + def is_int1(self): + return self.name == 'int1' + + def is_int8(self): + return self.name == 'int8' + + def is_int16(self): + return self.name == 'int16' + + def is_int32(self): + return self.name == 'int32' + + def is_int64(self): + return self.name == 'int64' + + def is_uint8(self): + return self.name == 'uint8' + + def is_uint16(self): + return self.name == 'uint16' + + def is_uint32(self): + return self.name == 'uint32' + + def is_uint64(self): + return self.name == 'uint64' + + def is_floating(self): + return self.name in dtype.FP_TYPES + + def is_standard_floating(self): + return self.name in dtype.STANDARD_FP_TYPES + + def is_int_signed(self): + return self.name in dtype.SINT_TYPES + + def is_int_unsigned(self): + return self.name in dtype.UINT_TYPES + + def is_int(self): + return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES + + def is_bool(self): + return self.is_int1() + + def kind(self): + # Return int value following the type ordering bool < integer < fp + if self.is_bool(): + return dtype.KIND.BOOLEAN + elif self.is_int(): + return dtype.KIND.INTEGRAL + else: + assert self.is_floating() + return dtype.KIND.FLOATING + + def get_int_max_value(self): + if self.is_int_signed(): + return 2**(self.int_bitwidth - 1) - 1 + if self.is_int_unsigned(): + return 2**self.int_bitwidth - 1 + assert False + + def get_int_min_value(self): + if self.is_int_signed(): + return -2**(self.int_bitwidth - 1) + if self.is_int_unsigned(): + return 0 + assert False + + @staticmethod + def is_dtype(type_str): + return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES + + @staticmethod + 
def is_void(): + raise RuntimeError("Not implemented") + + @staticmethod + def is_block(): + return False + + @staticmethod + def is_ptr(): + return False + + @staticmethod + def is_const(): + return False + + def __eq__(self, other: dtype): + if not isinstance(other, dtype): + return False + return self.name == other.name + + def __hash__(self): + return hash((self.name, )) + + @property + def scalar(self): + return self + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + out.append(self.to_ir(builder)) + + def to_ir(self, builder: ir.builder) -> ir.type: + if self.name.startswith("fp8"): + if self.name not in builder.options.supported_fp8_dtypes: + raise ValueError(f'type {self} not supported in this architecture. ' + f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}') + + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name in ('int8', 'uint8'): + return builder.get_int8_ty() + elif self.name in ('int16', 'uint16'): + return builder.get_int16_ty() + elif self.name in ('int32', 'uint32'): + return builder.get_int32_ty() + elif self.name in ('int64', 'uint64'): + return builder.get_int64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e5b16': + return builder.get_fp8e5b16_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b8': + return builder.get_fp8e4b8_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + def __str__(self): + return self.name + + def codegen_name(self): + if self.name.startswith("fp"): + return "float" + self.name[2:] + elif self.name.startswith("bf"): + return "bfloat" + self.name[2:] + else: + return self.name + + @property + def cache_key_part(self) -> str: + """See cache_key_part() in triton.cc.""" + return self.name + + def __repr__(self): + """Output of repr needs to be an evaluatable expression""" + return f'triton.language.{self.codegen_name()}' + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]: + return tensor(handles[cursor], self), cursor + 1 + + def mangle(self) -> str: + if self.is_int(): + SIGNED = dtype.SIGNEDNESS.SIGNED + prefix = 'i' if self.int_signedness == SIGNED else 'u' + return prefix + str(self.int_bitwidth) + if self.is_floating(): + return str(self) + if self.is_void(): + return 'V' + return super().mangle() + + def with_element_ty(self, element_ty: dtype): + assert not self.is_block() + return element_ty + + +class block_type(dtype): + + def __init__(self, element_ty: dtype, shape: List): + self.element_ty = element_ty + + # Note that block_type's shape is a list of int + # while tensor's shape is a list of constexpr. + assert (isinstance(shape, (list, tuple))) + + # shape can be empty ([]) when an input is a 0D tensor. 
+ self.shape = tuple(_unwrap_shape(shape)) + if not self.shape: + raise TypeError('0d block_type is forbidden') + + self.numel = validate_block_shape(self.shape) + self.name = f'<{self.shape}, {self.element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.block_type: + return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_block(self): + return True + + def get_block_shapes(self) -> Tuple[int]: + return self.shape + + def with_element_ty(self, scalar_ty: dtype) -> block_type: + return block_type(scalar_ty, self.shape) + + def __eq__(self, other) -> bool: + if not isinstance(other, block_type): + return False + return self.element_ty == other.element_ty and self.shape == other.shape + + @property + def scalar(self): + return self.element_ty + + def mangle(self) -> str: + elt = self.scalar.mangle() + shape = '_'.join(map(str, self.shape)) + return f'{elt}S{shape}S' + + +class tuple_type(base_type): + + def __init__(self, types, fields=None): + self.types = types + self.fields = fields or [''] * len(types) + self.name = '[' + ','.join([f"{k}:{v}" for k, v in zip(self.fields, self.types)]) + ']' + + def __str__(self): + return self.name + + def __iter__(self): + return iter(self.types) + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]): + for ty in self.types: + if not isinstance(ty, constexpr): + ty._flatten_ir_types(builder, out) + + def __getitem__(self, index: int) -> dtype: + return self.types[index] + + def __eq__(self, other): + return type(self) is type(other) and self.types == other.types and self.fields == other.fields + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tuple, int]: + values = [] + for ty in self.types: + value, cursor = ty._unflatten_ir(handles, cursor) + values.append(value) + return tuple(values, self), cursor + + def mangle(self): + return 'T' + '_'.join(ty.mangle for ty in self.types) + 'T' + + +class tensor_descriptor_base_type(base_type): + + def __init__(self, block_type: block_type): + self.block_type = block_type + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]: + value = tensor_descriptor_base(handles[cursor], self.block_type) + return value, cursor + 1 + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + is_signed = self.block_type.element_ty.is_int_signed() + out.append(builder.create_tensor_descriptor_type(self.block_type.to_ir(builder), is_signed)) + + def __str__(self) -> str: + # ex. 
"tensor_descriptor" + return f"tensor_descriptor<{self.block_type}>" + + def __eq__(self, other) -> bool: + if type(other) is not type(self): + return False + return self.block_type == other.block_type + + def __neq__(self, other) -> bool: + return not (self == other) + + def mangle(self) -> str: + return f"TD{self.block_type.mangle()}" + + +class tensor_descriptor_base(base_value): + """" + A tensor descriptor with unknown shape and strides + """ + + def __init__(self, handle, block_type: block_type): + """Not called by user code.""" + super().__init__(handle) + + self.handle = handle # IR handle + self.type = tensor_descriptor_base_type(block_type) # Tensor type (block_type) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + + @property + def block_type(self): + return self.type.block_type + + @property + def block_shape(self): + return self.type.block_type.shape + + @property + def dtype(self): + return self.type.block_type.element_ty + + def __str__(self) -> str: + return str(self.type) + + @builtin + def load(self, offsets: Sequence[constexpr | tensor], _builder=None) -> tensor: + """Load a block from the descriptor starting at the given element offsets. + + Values outside of the tensor bounds will be filled with zeros. + + :note: Offset must be a multiple of 16-bytes + """ + return descriptor_load(self, offsets, "", "", _builder) + + @builtin + def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _builder=None) -> tensor: + """Store a block from the descriptor starting at the given element offsets. + + Values outside of the tensor bounds will be ignored. + + :note: Offset must be a multiple of 16-bytes + """ + return descriptor_store(self, value, offsets, _builder) + + +class tensor_descriptor_type(tensor_descriptor_base_type): + + def __init__(self, block_type: block_type, shape_type: tuple_type, strides_type: tuple_type): + self.block_type = block_type + self.shape_type = shape_type + self.strides_type = strides_type + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]: + handle = handles[cursor] + cursor += 1 + shape, cursor = self.shape_type._unflatten_ir(handles, cursor) + strides, cursor = self.strides_type._unflatten_ir(handles, cursor) + shape = shape.values + strides = strides.values + value = tensor_descriptor(handle, shape, strides, self.block_type) + return value, cursor + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + super()._flatten_ir_types(builder, out) + self.shape_type._flatten_ir_types(builder, out) + self.strides_type._flatten_ir_types(builder, out) + + def __eq__(self, other): + return super().__eq__(other) and (self.shape_type == other.shape_type) and (self.strides_type + == other.strides_type) + + +class tensor_descriptor(tensor_descriptor_base): + """A descriptor representing a tensor in global memory. 
+ """ + + def __init__(self, handle, shape: List[tensor], strides: List[tensor], block_type: block_type): + """Not called by user code.""" + # IR handle + super().__init__(handle, block_type) + # Global shape + self.shape = tuple(shape) + self.strides = tuple(strides) + self.type = tensor_descriptor_type( + block_type, + shape_type=self.shape.type, + strides_type=self.strides.type, + ) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + self.shape._flatten_ir(handles) + self.strides._flatten_ir(handles) diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/testing.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/testing.py index 15c856d9a..1cea9a5e4 100644 --- a/third_party/ascend/backend/flagtree_backend_specialization/triton/testing.py +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/testing.py @@ -109,208 +109,3 @@ def ext_do_bench_npu(fn, warmup, rep, quantiles, return_mode): from triton.testing import _summarize_statistics avg_time = do_bench_npu(fn, warmup=max(5, warmup), active=max(30, rep)) return _summarize_statistics(torch.tensor([avg_time], dtype=torch.float), quantiles, return_mode) - - -def patch_triton_language(): - # Patch the triton language API here because triton's __init__.py - # import testing in the last stages. - from triton.language.tensor_descriptor import ( - tensor_descriptor, - tensor_descriptor_type, - ) - - from triton.language.core_ext import ( - dot, - cast, - gather, - get_element, - insert_slice, - extract_slice, - trans, - __add__, - __radd__, - __sub__, - __rsub__, - __mul__, - __rmul__, - __lshift__, - __rshift__, - parallel, - compile_hint, - make_tensor_descriptor, - load_tensor_descriptor, - store_tensor_descriptor, - multibuffer, - sync_block_all, - sync_block_set, - sync_block_wait, - dtype_to_ir, - sort - ) - from triton.language.standard_ext import flip, sigmoid, softmax, isfinited, finitef, rint, atan2 - from triton.language.math_ext import ( - umulhi, - exp, - exp2, - log, - log2, - cos, - sin, - sqrt, - sqrt_rn, - rsqrt, - div_rn, - erf, - tanh, - floor, - ceil, - _check_dtype, - fma, - ) - from triton.language.semantic_ext import ( - arange, - floordiv, - atom_red_typechecking_impl, - atomic_cas, - atomic_max, - atomic_min, - _load_legacy, - maximum, - minimum, - mod, - invert, - logical_and, - logical_or, - not_, - and_, - or_, - xor_, - minus, - dot_scaled, - ) - from triton import language - - language.cast = cast - language.dot = dot - language.flip = flip - language.sigmoid = sigmoid - language.softmax = softmax - language.gather = gather - language.insert_slice = insert_slice - language.extract_slice = extract_slice - language.get_element = get_element - language.tensor.__add__ = __add__ - language.tensor.__radd__ = __radd__ - language.tensor.__sub__ = __sub__ - language.tensor.__rsub__ = __rsub__ - language.tensor.__mul__ = __mul__ - language.tensor.__rmul__ = __rmul__ - language.tensor.__lshift__ = __lshift__ - language.tensor.__rshift__ = __rshift__ - language.trans = trans - language.parallel = parallel - language.compile_hint = compile_hint - language.sort = sort - language.multibuffer = multibuffer - language.sync_block_all = sync_block_all - language.sync_block_set = sync_block_set - language.sync_block_wait = sync_block_wait - language.make_tensor_descriptor = make_tensor_descriptor - language.tensor_descriptor = tensor_descriptor - language.tensor_descriptor_type = tensor_descriptor_type - language.load_tensor_descriptor 
= load_tensor_descriptor - language.store_tensor_descriptor = store_tensor_descriptor - - language.semantic.arange = arange - language.semantic.floordiv = floordiv - language.semantic.atom_red_typechecking_impl = atom_red_typechecking_impl - language.semantic.atomic_cas = atomic_cas - language.semantic.atomic_max = atomic_max - language.semantic.atomic_min = atomic_min - language.semantic._load_legacy = _load_legacy - language.semantic.maximum = maximum - language.semantic.minimum = minimum - language.semantic.invert = invert - language.semantic.logical_and = logical_and - language.semantic.logical_or = logical_or - language.semantic.mod = mod - language.semantic.not_ = not_ - language.semantic.and_ = and_ - language.semantic.or_ = or_ - language.semantic.xor_ = xor_ - language.semantic.minus = minus - language.semantic.dot_scaled = dot_scaled - - language.umulhi = umulhi - language.exp = exp - language.exp2 = exp2 - language.log = log - language.log2 = log2 - language.cos = cos - language.sin = sin - language.sqrt = sqrt - language.sqrt_rn = sqrt_rn - language.rsqrt = rsqrt - language.div_rn = div_rn - language.erf = erf - language.tanh = tanh - language.floor = floor - language.ceil = ceil - language.core.dtype.to_ir = dtype_to_ir - language.fma = fma - language.math.umulhi = umulhi - language.math.exp = exp - language.math.exp2 = exp2 - language.math.log = log - language.math.log2 = log2 - language.math.cos = cos - language.math.sin = sin - language.math.sqrt = sqrt - language.math.sqrt_rn = sqrt_rn - language.math.rsqrt = rsqrt - language.math.div_rn = div_rn - language.math.erf = erf - language.math.tanh = tanh - language.math.floor = floor - language.math.ceil = ceil - language.math._check_dtype = _check_dtype - language.math.fma = fma - language.math.isnan = language.extra.ascend.libdevice.isnan - language.math.isinf = language.extra.ascend.libdevice.isinf - language.math.reciprocal = language.extra.ascend.libdevice.reciprocal - language.math.log1p = language.extra.ascend.libdevice.log1p - language.math.relu = language.extra.ascend.libdevice.relu - language.math.tan = language.extra.ascend.libdevice.tan - language.math.atan = language.extra.ascend.libdevice.atan - language.math.tanh = language.extra.ascend.libdevice.tanh - language.math.ilogb = language.extra.ascend.libdevice.ilogb - language.math.ldexp = language.extra.ascend.libdevice.ldexp - language.math.pow = language.extra.ascend.libdevice.pow - language.math.flip = language.extra.ascend.libdevice.flip - language.math.atan2 = language.extra.ascend.libdevice.atan2 - language.math.div_rz = language.extra.ascend.libdevice.div_rz - language.math.fmod = language.extra.ascend.libdevice.fmod - language.math.trunc = language.extra.ascend.libdevice.trunc - language.math.round = language.extra.ascend.libdevice.round - language.math.finitef = finitef - language.math.isfinited = isfinited - language.math.rint = rint - language.math.atan2 = atan2 - language.extra.ascend.libdevice.umulhi = language.math.umulhi - language.extra.ascend.libdevice.exp = language.math.exp - language.extra.ascend.libdevice.exp2 = language.math.exp2 - language.extra.ascend.libdevice.log = language.math.log - language.extra.ascend.libdevice.log2 = language.math.log2 - language.extra.ascend.libdevice.cos = language.math.cos - language.extra.ascend.libdevice.sin = language.math.sin - language.extra.ascend.libdevice.sqrt = language.math.sqrt - language.extra.ascend.libdevice.sqrt_rn = language.math.sqrt_rn - language.extra.ascend.libdevice.rsqrt = language.math.rsqrt 
- language.extra.ascend.libdevice.div_rn = language.math.div_rn - language.extra.ascend.libdevice.erf = language.math.erf - language.extra.ascend.libdevice.tanh = language.math.tanh - language.extra.ascend.libdevice.floor = language.math.floor - language.extra.ascend.libdevice.ceil = language.math.ceil - language.extra.ascend.libdevice.fdiv = language.math.fdiv - language.extra.ascend.libdevice.fma = language.math.fma - language.extra.ascend.libdevice.abs = language.math.abs \ No newline at end of file
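
For reference, a minimal usage sketch of the decoupled math path after this series is applied: testing.py now rebinds language.math.* to the Ascend libdevice implementations added above, which accept only bf16/fp16/fp32 inputs because of the _check_dtype decorator. The kernel below is illustrative only; the kernel and argument names are assumptions, and it presumes the flagtree Ascend backend and a suitable device are available.

import triton
import triton.language as tl

@triton.jit
def exp_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    # tl.math.exp resolves to language.extra.ascend.libdevice.exp after the
    # rebinding in testing.py; integer inputs must be cast to a float type
    # first, otherwise _check_dtype raises ValueError.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.math.exp(x.to(tl.float32))
    tl.store(y_ptr + offs, y, mask=mask)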